# Survival analysis

## Setup

In [None]:
# Number of parallel workers passed to joblib (n_jobs) by the bootstrap cell.
num_cores = 4

import pickle
import numpy as np
import pandas as pd

from scipy.stats import ttest_ind, mannwhitneyu, ks_2samp, spearmanr, normaltest, shapiro
import statsmodels.api as sm
from statsmodels.stats.multitest import multipletests

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn import linear_model
from sklearn.utils import resample
from sklearn import metrics
from sklearn.metrics import roc_curve,precision_recall_curve

# Colour-blind-friendly default palette for seaborn plots in this notebook.
sns.set_palette("colorblind")

# Paths, relative to this notebook, to the data directory (inputs and most
# outputs) and to the figure output directory.
dir_="../../data/"
img_dir_='../../docs/imgs/'

## Survival data (before averaging replicate samples)

In [None]:
# Load the non-deduplicated protein matrix (one row per sample, one column per
# protein) and the clinical/outcome table indexed by the same sample IDs.
X_all_proteins = pd.read_csv(dir_+'integrated_X_raw_all_proteins.csv',index_col=0)
display(X_all_proteins.head())
print(X_all_proteins.shape)
joined = pd.read_csv('../../data/mortality_X_y.csv',index_col=0)
display(joined.head())
print(joined.shape)

# One-hot cohort (site-of-origin) indicator columns.
cov_df = joined.loc[:,
                   ['Cohort_Columbia','Cohort_Cedar','Cohort_Paris']
                  ]
display(cov_df.head())
print(cov_df.shape)

# Protein -> gene-name lookup; rows without a gene name are dropped.
idmap_sub = pd.read_csv('../../data/protein_gene_map_full.csv')[['Protein','Gene_name']].dropna()
display(idmap_sub.head())

# Protein lists saved by an upstream step. Context managers close the file
# handles deterministically instead of leaving them to the garbage collector.
with open(dir_+'proteins_no_immunoglobulins.pkl','rb') as fh:
    proteins = pickle.load(fh)
with open(dir_+'proteins_immunoglobulins.pkl','rb') as fh:
    ig_proteins = pickle.load(fh)
print(len(proteins))
print(len(ig_proteins))

## Comparing And Averaging Replicates for Paris Samples

In [None]:
# Paris (Pi-*) biological replicate pairs, each as [first replicate, second replicate].
bioreps = [
    list(pair)
    for pair in (
        ('Pi-13-3-128N', 'Pi-13-4-126'),
        ('Pi-14-4-127C', 'Pi-14-4-128N'),
        ('Pi-17-1-128N', 'Pi-17-3-129C'),
        ('Pi-18-1-128N', 'Pi-18-4-129C'),
        ('Pi-19-1-129C', 'Pi-19-4-129N'),
        ('Pi-20-1-129N', 'Pi-20-4-130C'),
        ('Pi-21-1-130C', 'Pi-21-4-130N'),
    )
]

In [None]:
# For each replicate pair, scatter the two replicates' expression against each
# other (one point per protein): once for the non-immunoglobulin proteins and
# once for the immunoglobulin proteins. Expects `bioreps` to hold [first, second]
# pairs of sample IDs.
for reps in bioreps:
    first, second = reps[0], reps[1]

    # plt.subplots returns (figure, axes). Draw on this axes explicitly and close
    # the figure after saving so figures do not pile up in memory.
    fig, ax = plt.subplots(dpi=300)
    sns.scatterplot(x=first, y=second,
                    data=X_all_proteins.T.loc[proteins], ax=ax)
    fig.savefig(img_dir_+first+'_vs_'+second+'.pdf')
    plt.close(fig)

    fig, ax = plt.subplots(dpi=300)
    sns.scatterplot(x=first, y=second,
                    data=X_all_proteins.T.loc[ig_proteins], ax=ax)
    fig.savefig(img_dir_+first+'_vs_'+second+'_ig_proteins.pdf')
    plt.close(fig)

### Protein correlations between biological replicates

In [None]:
# Reload the non-immunoglobulin protein list, closing the file after reading.
with open(dir_+'proteins_no_immunoglobulins.pkl','rb') as fh:
    proteins = pickle.load(fh)
X_all_proteins.loc[:,proteins]

In [None]:
# Spearman correlation between every pair of samples across the
# non-immunoglobulin proteins, reshaped into a long (Sample1, Sample2, Value)
# table with the upper triangle masked out and self-pairs removed.
corr_df = X_all_proteins.loc[:,proteins].T.corr(method='spearman')
# Lower-triangle (incl. diagonal) mask. Use the builtin `bool`: the `np.bool`
# alias was removed in NumPy 1.24.
lowertribool = np.tril(np.ones(corr_df.shape)).astype(bool)
corr_df = corr_df.where(lowertribool)
corr_df_melt = corr_df.stack().rename_axis(['Sample1','Sample2']).reset_index()
corr_df_melt.columns = ['Sample1','Sample2','Value']
corr_df_melt = corr_df_melt.query('Sample1!=Sample2')
corr_df_melt

In [None]:
# Flat list of Paris replicate sample IDs used by the correlation summaries below.
# NOTE: this rebinds `bioreps`, which previously held [first, second] pairs.
# NOTE(review): the Pi-17 pair ('Pi-17-1-128N', 'Pi-17-3-129C') listed in the
# pair list is absent here — confirm whether that omission is intentional.
bioreps = ['Pi-13-3-128N', 'Pi-13-4-126',
       'Pi-14-4-127C', 'Pi-14-4-128N', 'Pi-18-1-128N', 'Pi-18-4-129C',
       'Pi-19-1-129C', 'Pi-19-4-129N', 'Pi-20-1-129N',
       'Pi-20-4-130C', 'Pi-21-1-130C', 'Pi-21-4-130N']
# The Pi-19 replicate pair, inspected separately in the next cell (the name
# suggests this patient died — not verifiable from this notebook).
diedbioreps = ['Pi-19-1-129C', 'Pi-19-4-129N']
# Sample IDs of X_all_proteins that are not in the replicate list above.
notbioreps = np.setdiff1d(X_all_proteins.index,bioreps)

In [None]:
# Correlations whose Sample1 is one of the Pi-19 replicates, lowest first.
corr_df_melt.query('Sample1 in @diedbioreps').sort_values('Value')

In [None]:
# Number of sample pairs with Spearman correlation > 0.5,
print(corr_df_melt.query('Value>0.5').shape[0])
# how many of those have a Paris replicate as Sample1,
print(corr_df_melt.query('Sample1 in @bioreps & Value>0.5').shape[0])
# the total number of sample pairs,
print(corr_df_melt.shape[0])
# and the fraction of all pairs with correlation > 0.5.
corr_df_melt.query('Value>0.5').shape[0]/corr_df_melt.shape[0]

### Averaging Replicates for Paris Samples

In [None]:
# Average each Paris replicate pair into its first sample and drop the second
# sample from both the protein matrix and the clinical table.
# The pairs are spelled out here rather than read from `bioreps`: by this point
# `bioreps` has been rebound to a flat list of sample IDs. Iterating that list
# would make reps[0]/reps[1] single characters of an ID and corrupt the data.
replicate_pairs = [['Pi-13-3-128N', 'Pi-13-4-126'],
                   ['Pi-14-4-127C', 'Pi-14-4-128N'],
                   ['Pi-17-1-128N', 'Pi-17-3-129C'],
                   ['Pi-18-1-128N', 'Pi-18-4-129C'],
                   ['Pi-19-1-129C', 'Pi-19-4-129N'],
                   ['Pi-20-1-129N', 'Pi-20-4-130C'],
                   ['Pi-21-1-130C', 'Pi-21-4-130N']]
print(X_all_proteins.shape)
print(joined.shape)
for keep, drop in replicate_pairs:
    # Column-wise (per-protein) mean of the two replicate rows.
    X_all_proteins.loc[keep] = X_all_proteins.loc[[keep, drop]].mean()
    X_all_proteins = X_all_proteins.drop(drop,axis=0)
    joined = joined.drop(drop,axis=0)
print(X_all_proteins.shape)
print(joined.shape)

In [None]:
# Persist the replicate-averaged protein matrix and clinical table; later cells
# reload them from these files.
X_all_proteins.to_csv(dir_+'integrated_X_raw_all_proteins_dedup.csv')
display(X_all_proteins.head())
print(X_all_proteins.shape)
joined.to_csv('../../data/mortality_X_y_dedup.csv')

## Survival data (after averaging replicate samples)

In [None]:
# Load the replicate-averaged protein matrix and clinical/outcome table.
X_all_proteins = pd.read_csv(dir_+'integrated_X_raw_all_proteins_dedup.csv',index_col=0)
display(X_all_proteins.head())
print(X_all_proteins.shape)
joined = pd.read_csv('../../data/mortality_X_y_dedup.csv',index_col=0)
display(joined.head())
print(joined.shape)

# One-hot cohort (site-of-origin) indicator columns.
cov_df = joined.loc[:,
                   ['Cohort_Columbia','Cohort_Cedar','Cohort_Paris']
                  ]
display(cov_df.head())
print(cov_df.shape)

# Protein -> gene-name lookup; rows without a gene name are dropped.
idmap_sub = pd.read_csv('../../data/protein_gene_map_full.csv')[['Protein','Gene_name']].dropna()
display(idmap_sub.head())

# Protein lists; context managers close the pickle files after reading.
with open(dir_+'proteins_no_immunoglobulins.pkl','rb') as fh:
    proteins = pickle.load(fh)
with open(dir_+'proteins_immunoglobulins.pkl','rb') as fh:
    ig_proteins = pickle.load(fh)
print(len(proteins))
print(len(ig_proteins))

## Protein distribution (Figure S2)

In [None]:
# Per-cohort expression matrices, transposed to proteins x samples. Each cohort's
# samples are the rows of `joined` whose one-hot cohort indicator equals 1.
cumc, cedar, paris = (
    X_all_proteins.loc[joined[joined[cohort_flag] == 1].index.values].T
    for cohort_flag in ('Cohort_Columbia', 'Cohort_Cedar', 'Cohort_Paris')
)

In [None]:
def _standardize_cohort(cohort_matrix, protein_list, cohort_label):
    """
    Z-score each protein across one cohort's samples and return a long table.

    Parameters
    ----------
    cohort_matrix : DataFrame
        Proteins x samples expression matrix for a single cohort.
    protein_list : list-like
        Row labels (proteins) to keep.
    cohort_label : str
        Value written into the added 'Cohort' column.

    Returns
    -------
    DataFrame
        Output of ``melt(id_vars='Protein')``: columns 'Protein', the sample
        column (named after the matrix's column-axis name, or 'variable' when
        unnamed) and 'value', plus 'Cohort'. Standardization uses np.mean and
        np.std (population std, ddof=0) per protein row.
    """
    long_df = (cohort_matrix.
               rename_axis('Protein').
               loc[protein_list].
               apply(lambda x: (x - np.mean(x)) / np.std(x), axis=1).
               reset_index().
               melt(id_vars='Protein'))
    long_df['Cohort'] = cohort_label
    return long_df


# The same transformation, previously copy-pasted once per cohort.
cumc_df = _standardize_cohort(cumc, ig_proteins, 'Columbia')
cedar_df = _standardize_cohort(cedar, ig_proteins, 'Cedar-Sinai')
paris_df = _standardize_cohort(paris, ig_proteins, 'Pitíe Salpetriere')

In [None]:
# Figure S2: one panel per cohort, overlaying one histogram per sample of that
# sample's standardized expression values. Global font/title sizes are set first.
matplotlib.rcParams['axes.titlepad'] = 8
matplotlib.rcParams['axes.titlesize'] = 16
matplotlib.rcParams['axes.labelsize'] = 16
matplotlib.rcParams['xtick.labelsize'] = 16
matplotlib.rcParams['ytick.labelsize'] = 16
dpi=300

# Three stacked panels sharing both x and y axes.
fig,ax = plt.subplots(nrows=3,ncols=1,sharex=True,sharey=True,dpi=dpi,figsize=(6,4))

cohorts=['Columbia','Cedar-Sinai','Pitíe Salpetriere']

# NOTE(review): groupby('Sample') requires the melted frames to have a 'Sample'
# column, i.e. the sample axis must be named 'Sample' — confirm against the CSV.
# NOTE(review): Axes.set_alpha changes the axes artist's alpha, not the bars'.
for i,grp in cumc_df.groupby('Sample'):
    sns.distplot(grp['value'],
                 color='Blue',
                 label=cohorts[0],
                 kde=False,
                 ax=ax[0])
    ax[0].set_alpha(0.8)

for i,grp in cedar_df.groupby('Sample'):
    sns.distplot(grp['value'],
                 color='Green',
                 label=cohorts[1],
                 kde=False,
                 ax=ax[1])
    ax[1].set_alpha(0.8)

for i,grp in paris_df.groupby('Sample'):
    sns.distplot(grp['value'],
                 color='red',
                 label=cohorts[2],
                 kde=False,
                 ax=ax[2])
    ax[2].set_alpha(0.8)
sns.despine()
ax[0].set_xlabel('')
ax[1].set_xlabel('')

# Cohort name written inside each panel at data coordinates (2.5, 50).
for i,a in enumerate(ax):
    a.text(2.5,50,cohorts[i])
    a.set_xlim(-5,5)
# NOTE(review): with kde=False and distplot's default norm_hist=False the bars
# show counts, so the 'Density' label may be inaccurate — confirm.
ax[1].set_ylabel('Density')
ax[1].yaxis.set_label_coords(-0.1,0)
ax[0].set_title('Exosome protein expression distribution')
ax[2].set_xlabel('Standardized protein expression')

fig.tight_layout()
# NOTE(review): saved under dir_ (the data directory), not img_dir_ — confirm.
fig.savefig(dir_+'ProteinDescription_distributions.pdf')

In [None]:
# Normality tests (D'Agostino-Pearson `normaltest` and Shapiro-Wilk) on each
# cohort's pooled standardized values; each cell holds scipy's
# (statistic, pvalue) result. scipy's shapiro warns that p-values may be
# inaccurate for N > 5000.
pd.DataFrame({
    'normaltest' : [normaltest(cumc_df.value.values),
                   normaltest(cedar_df.value.values),
                   normaltest(paris_df.value.values)],
    'shapiro' : [shapiro(cumc_df.value.values),
                   shapiro(cedar_df.value.values),
                   shapiro(paris_df.value.values)]
},index=['cumc','cedar','paris'])

In [None]:
# Pairwise two-sample Kolmogorov-Smirnov tests between the cohorts' pooled
# standardized expression values.
print(
ks_2samp(cumc_df.value.values,cedar_df.value.values)
)
print(
ks_2samp(cumc_df.value.values,paris_df.value.values)
)
print(
ks_2samp(cedar_df.value.values,paris_df.value.values)
)

## Protein distribution tests

In [None]:
# Reload the deduplicated data.
X_all_proteins = pd.read_csv(dir_+'integrated_X_raw_all_proteins_dedup.csv',index_col=0)
joined = pd.read_csv('../../data/mortality_X_y_dedup.csv',index_col=0)

# Min-max scale each protein column to [0, 1].
X = X_all_proteins.copy().apply(lambda x : (x - x.min()) / (x.max() - x.min()))
# Outcome frame aligned to X's rows: Survival = 1 when expired == 0, else 0.
# (The repeated [['expired']] selection is redundant but harmless.)
Y = joined[['expired']][['expired']].copy()
Y = Y.loc[X.index.values]
Y = (Y==0).astype(int)
Y.columns = ['Survival']

In [None]:
# For every protein that has a gene-name mapping, compare raw expression
# between survivors (Survival == 1) and non-survivors (Survival == 0). Save a
# swarm plot titled with the t statistic and Mann-Whitney p-value, and record
# [U / (n_survived * n_died), var(survived), var(died), gene] in `lst`.
prots = X.columns.values
# Row labels of each outcome group, computed once rather than per protein.
survived_idx = Y.index[Y['Survival'] == 1]
died_idx = Y.index[Y['Survival'] == 0]
lst = []
for prot in prots:
    print(prot)
    gene_rows = idmap_sub.query('Protein==@prot')
    if gene_rows.shape[0] == 0:
        continue
    gene = gene_rows.Gene_name.values[0]

    x = X_all_proteins.loc[survived_idx, prot].values
    y = X_all_proteins.loc[died_idx, prot].values
    # Single Mann-Whitney call reused for both the title and the record.
    mw_stat, mw_pvalue = mannwhitneyu(x, y)

    fig, ax = plt.subplots(figsize=(8, 6), dpi=300)
    # Data variables must be passed as keywords (seaborn >= 0.12).
    sns.swarmplot(
        x='Survival',
        y='expr',
        data=pd.concat(
            [
                pd.DataFrame({'expr': x, 'Survival': 'Survived'}),
                pd.DataFrame({'expr': y, 'Survival': 'Died'}),
            ]
        ),
        ax=ax,
    )
    ax.set_xlabel(None)
    ax.set_ylabel('Protein expression')
    ax.set_title(
        gene+
        '\nT test statistic = '+str(np.round(ttest_ind(x,y)[0],3))+
        '; Mann Whitney Pvalue = '+str(np.round(mw_pvalue,3))
    )
    fig.savefig('../../docs/imgs/survival_protein_tests/'+gene+'_test.png')
    plt.close(fig)
    lst.append([mw_stat/(len(x)*len(y)),np.var(x),np.var(y),gene])

## Differential analysis functions

In [None]:
def balanced_resample(Y, seed=42):
    """
    Balance a binary pandas Series by class, then bootstrap each class.

    The Series is sorted, each class is truncated to the size of the smallest
    class, and each class subset is bootstrap-resampled with
    ``resample(..., random_state=seed)``.

    Parameters
    ----------
    Y : pandas.Series
        Binary class labels indexed by sample. Not modified.
    seed : int
        Random state passed to ``resample``.

    Returns
    -------
    pandas.Series
        Concatenation of the per-class bootstrap samples.
    """
    # Sort a copy. Sorting the caller's Series in place reordered it as a
    # side effect of every call.
    Y_sorted = Y.sort_values()
    num_to_sample = Y_sorted.value_counts().min()

    dfs = []
    for grp in Y_sorted.unique():
        y = Y_sorted[Y_sorted == grp].head(num_to_sample)
        dfs.append(resample(y, random_state=seed))

    return pd.concat(dfs)

def pull_logit_coefficients(fit):
    """
    Return ``fit.coef_[0][0]``.

    For a fitted binary sklearn linear classifier this is the coefficient of
    the first feature column.
    """
    first_row = fit.coef_[0]
    return first_row[0]

def coef_to_prob(coef):
    """Convert a log-odds value into a probability: exp(coef) / (1 + exp(coef))."""
    return np.exp(coef) / (1 + np.exp(coef))

def coef_to_odds(coef):
    """Convert a log-odds coefficient into an odds ratio, ``exp(coef)``."""
    return np.exp(coef)

def prediction(X, Y, model, seed=42):
    """
    Fit ``model`` on one bootstrap sample and return the first feature's odds ratio.

    Rows of ``Y`` are bootstrap-resampled (no class balancing; see
    ``balanced_prediction`` for the balanced variant). The matching rows of
    ``X`` are selected and ``model`` is fitted.

    Parameters
    ----------
    X : pandas.DataFrame
        Feature matrix indexed like ``Y``. The first column is the feature
        whose coefficient is reported.
    Y : pandas.Series or single-column pandas.DataFrame
        Binary target.
    model : sklearn estimator
        Linear classifier exposing ``coef_`` after ``fit``.
    seed : int
        Random state for the bootstrap resample.

    Returns
    -------
    float
        ``exp(coef)`` for the first feature, i.e. an odds ratio, not a probability.
    """
    Y_boot = resample(Y, random_state=seed)
    X_boot = X.loc[Y_boot.index]

    # sklearn expects a 1-D target; a column vector triggers a
    # DataConversionWarning and is raveled internally anyway.
    fit = model.fit(X_boot, np.ravel(Y_boot.values))

    coef = pull_logit_coefficients(fit)

    return coef_to_odds(coef)

def balanced_prediction(X, Y, model, seed=42):
    """
    Fit ``model`` on a class-balanced bootstrap sample and return an odds ratio.

    ``Y`` (a binary pandas Series) is passed through ``balanced_resample``, the
    matching rows of ``X`` are selected, and ``model`` is fitted. The first
    feature's coefficient is returned on the odds scale (``exp(coef)``), not as
    a probability.
    """
    resampled_y = balanced_resample(Y, seed=seed)
    fitted = model.fit(X.loc[resampled_y.index], resampled_y)
    return coef_to_odds(pull_logit_coefficients(fitted))

def bootstrap_prediction_transformations(odds_boot, var='variable'):
    """
    Summarise bootstrap odds ratios per variable.

    Parameters
    ----------
    odds_boot : dict
        Maps each variable name to an equal-length list of bootstrap odds ratios.
    var : str
        Name used for the variable column/index in the outputs.

    Returns
    -------
    output : pandas.DataFrame
        Long table indexed by ``var`` with columns 'bootstrap' (bootstrap
        number), 'odds' and 'bootstrap_median' (the variable's median odds).
    err_df : pandas.DataFrame
        One row per variable (index named ``var``, in order of first
        appearance) with columns 'lwr' (2.5% quantile), 'mean',
        'median' (50% quantile) and 'upr' (97.5% quantile) of the odds.
    """
    medians = (
        pd.Series(
            {key: np.median(values) for key, values in odds_boot.items()},
            name='bootstrap_median',
        )
        .rename_axis(var)
        .sort_values(ascending=False)
    )
    output = (pd.DataFrame.from_dict(odds_boot).
              reset_index().rename(columns={'index': 'bootstrap'}).
              melt(id_vars='bootstrap', var_name=var, value_name='odds').
              set_index(var).
              join(medians))

    # Group once rather than building a query string per variable. The string
    # approach broke on names containing quotes and on non-string keys.
    err_df = output.groupby(level=var, sort=False)['odds'].agg(
        lwr=lambda s: s.quantile(.025),
        mean='mean',
        median=lambda s: s.quantile(.5),
        upr=lambda s: s.quantile(.975),
    )
    return output, err_df

from joblib import Parallel, delayed

def bootstrap_of_fcn(func=None, params=None, n_jobs=4, nboot=2):
    """
    Run ``func(seed=k, **params)`` for k in 0..nboot-1 in parallel with joblib.

    Parameters
    ----------
    func : callable
        Function accepting a ``seed`` keyword plus ``params``.
    params : dict, optional
        Extra keyword arguments for ``func``. Defaults to no extra arguments.
    n_jobs : int
        Number of joblib workers.
    nboot : int
        Number of bootstrap runs.

    Returns
    -------
    list
        ``func``'s return values, in seed order.

    Raises
    ------
    ValueError
        If ``func`` is None. (This used to return an error string that callers
        would have treated as a list of results.)
    """
    if func is None:
        raise ValueError("Need fcn to bootstrap")
    # None sentinel instead of a shared mutable {} default.
    if params is None:
        params = {}
    parallel = Parallel(n_jobs=n_jobs)
    return parallel(delayed(func)(seed=k, **params) for k in range(nboot))


## Survival multivariate logistic regression

In [None]:
# Reload the deduplicated protein matrix and clinical/outcome table.
X_all_proteins = pd.read_csv(dir_+'integrated_X_raw_all_proteins_dedup.csv',index_col=0)
joined = pd.read_csv('../../data/mortality_X_y_dedup.csv',index_col=0)

In [None]:
# Min-max scale each protein column, then build the survival target
# y.Survival = 1 when expired == 0, else 0, aligned to X's rows.
# (X is replaced by the clinical-variable matrix in the following cell.)
X = X_all_proteins.copy().apply(lambda x : (x - x.min()) / (x.max() - x.min()))
Y = joined[['expired']][['expired']].copy()
Y = Y.loc[X.index.values]
y = (Y==0).astype(int)
y.columns = ['Survival']
# NOTE(review): presumably named to match the design matrix passed to
# sm.Logit — confirm it is needed.
y.index.name = 'index'

In [None]:
# Show floats with 5 decimals in displayed tables.
pd.set_option('display.float_format', lambda x: '%.5f' % x)
# Clinical variables used as predictors; commented-out entries are excluded.
vars_=\
['Age','BMI','CVP','CVP/PCWP','Creatinine',
 'INR','Ischemic_Time','Blood_Type_A','Blood_Type_B',
'Blood_Type_O',
 #'Blood_Type_AB',
 'Donor_Age',
#'Radial_Score',
 'Sodium','MELDXI',
'PA_Diastolic','PA_Mean','PA_Systolic','PCWP','TBILI',
'Antiarrhythmic_Use_Y','Beta_Blocker_Y','Cardiomyopathy_Adriamycin',
'Cardiomyopathy_Amyloid',
 #'Cardiomyopathy_Chagas',
 'Cardiomyopathy_Congenital',
#'Cardiomyopathy_Hypertrophic cardiomyopathy','Cardiomyopathy_Idiopathic',
'Cardiomyopathy_Ischemic','Cardiomyopathy_Myocarditis',
#'Cardiomyopathy_Valvular Heart Disease','Cardiomyopathy_Viral',
'Diabetes_Y','History_Of_Tobacco_Use_Y','Mechanical_Support_Y',
'Prior_Inotrope_Y','PGD'
 #'Sex_F'
]
# Design matrix for the clinical logistic regression.
X = joined[vars_]

In [None]:
# Same clinical-variable list as the previous cell (duplicated; keep the two in
# sync), pickled so the prediction pipeline can reuse it.
vars_=\
['Age','BMI','CVP','CVP/PCWP','Creatinine',
 'INR','Ischemic_Time','Blood_Type_A','Blood_Type_B',
'Blood_Type_O',
 #'Blood_Type_AB',
 'Donor_Age',
#'Radial_Score',
 'Sodium','MELDXI',
'PA_Diastolic','PA_Mean','PA_Systolic','PCWP','TBILI',
'Antiarrhythmic_Use_Y','Beta_Blocker_Y','Cardiomyopathy_Adriamycin',
'Cardiomyopathy_Amyloid',
 #'Cardiomyopathy_Chagas',
 'Cardiomyopathy_Congenital',
#'Cardiomyopathy_Hypertrophic cardiomyopathy','Cardiomyopathy_Idiopathic',
'Cardiomyopathy_Ischemic','Cardiomyopathy_Myocarditis',
#'Cardiomyopathy_Valvular Heart Disease','Cardiomyopathy_Viral',
'Diabetes_Y','History_Of_Tobacco_Use_Y','Mechanical_Support_Y',
'Prior_Inotrope_Y','PGD'
 #'Sex_F'
]
with open(dir_+'prediction_clinical_variables.pickle', 'wb') as handle:
    pickle.dump(vars_, handle, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Multivariable logistic regression of Survival on the clinical variables.
# NOTE(review): sm.Logit does not add an intercept and X is not wrapped in
# sm.add_constant, so this model has no intercept. Dropping reference dummy
# levels (e.g. Blood_Type_AB) usually assumes one — confirm this is intended.
log_reg = sm.Logit(y,X).fit()
#https://stackoverflow.com/questions/51734180/converting-statsmodels-summary-object-to-pandas-dataframe
# Parse the coefficient table (coef, std err, z, P>|z|, 95% CI) into a DataFrame.
multi = pd.read_html(log_reg.summary().tables[1].as_html(), header=0, index_col=0)[0]
multi.columns=['coef','std_err','z','pvalue','lwr','upr']
print(multi.shape)
display(multi.sort_values('pvalue'))
# Number of variables with p < 0.05, and with a 95% CI that excludes 0.
print(multi.query('pvalue<0.05').shape)
print(multi.query('lwr>0 | upr<0').shape)
multi.to_csv(dir_+'survival_clinical_multivariate_model_statistics.csv')

## GSEA rank statistic (univariate, bootstrapped logistic regression)

In [None]:
# Reload the deduplicated data and the cohort covariates.
X_all_proteins = pd.read_csv(dir_+'integrated_X_raw_all_proteins_dedup.csv',index_col=0)
joined = pd.read_csv('../../data/mortality_X_y_dedup.csv',index_col=0)
cov_df = joined.loc[:,
                   ['Cohort_Columbia','Cohort_Cedar','Cohort_Paris']
                  ]
# NOTE: `proteins` is bound here to the contents of
# proteins_immunoglobulins.pkl (the list other cells call `ig_proteins`), so
# the univariate models below run over that list. The file is closed after
# reading.
with open('../../data/proteins_immunoglobulins.pkl','rb') as fh:
    proteins = pickle.load(fh)

In [None]:
# Min-max scale each protein column to [0, 1] and build the outcome frame
# Y.Survival = 1 when expired == 0, else 0, aligned to X's rows.
X = X_all_proteins.copy().apply(lambda x : (x - x.min()) / (x.max() - x.min()))
Y = joined[['expired']][['expired']].copy()
Y = Y.loc[X.index.values]
Y = (Y==0).astype(int)
Y.columns = ['Survival']

In [None]:
# Settings for the univariate L1-penalised logistic regression.
C = 1
seed = 42
# NOTE(review): tol is defined but never passed to LogisticRegression below, so
# sklearn's default tolerance is used — confirm whether tol=tol was intended.
tol = 1e-3

model = {
    "Logistic Regression": linear_model.LogisticRegression(
        penalty='l1',
        solver="liblinear",
        C=C,
        fit_intercept=True,
        random_state=seed,
    ),
}




In [None]:
nboot = 200

# For each protein, nboot bootstrap odds ratios of the protein's coefficient.
# Each model is fitted on the protein column plus the three cohort indicators.
boots = []
for prot in proteins:
    params = dict(
        X=X[[prot]].join(cov_df),
        Y=Y,
        model=model['Logistic Regression'],
    )
    lst = bootstrap_of_fcn(func=prediction, params=params, n_jobs=num_cores, nboot=nboot)
    boots.append(lst)

# protein -> its list of bootstrap odds ratios (proteins and boots are parallel).
odds_boot = dict(zip(proteins, boots))

output, err_df = bootstrap_prediction_transformations(odds_boot)



In [None]:

# Per protein: mean bootstrap odds from the univariate models, joined with the
# protein's mean raw expression in the selected patients and its gene name.
# Y.Survival == 1 marks patients with expired == 0 (survivors), and the models'
# target is Survival, so the odds describe association with survival. The axis
# labels below describe exactly that.
survivor_index = Y.index[Y.Survival==1]
dat = (
    output.
    reset_index().
    groupby('variable')['odds'].
    mean().
    sort_values().
    to_frame().
    join(
        # Per-protein column mean. Unlike reset_index()/melt(id_vars='index'),
        # this does not depend on how the sample index is named.
        X_all_proteins.
        loc[survivor_index].
        mean().
        rename('value').
        to_frame()
    ).
    join(
        idmap_sub.set_index('Protein'))
)
fig,ax = plt.subplots(dpi=300)
# Data variables passed as keywords (required by seaborn >= 0.12).
sns.scatterplot(x='odds',y='value',data=dat,ax=ax)
ax.set_ylabel('Average protein expression\nfrom patients who survived')
ax.set_xlabel('Average odds association to survival')
display(dat.query('odds<0.5'))
display(dat.query('odds>2'))

In [None]:
# Each protein's mean bootstrap odds, sorted ascending, drawn as a line.
fig,ax = plt.subplots(dpi=300,figsize=(12,4))
ax = output.reset_index().groupby('variable')['odds'].mean().sort_values().plot(ax=ax)

In [None]:
# Copy of the per-protein univariate table, tagged with the model type.
univariate = dat.copy()
# Plain single-column assignment. `df[['model']] = scalar` for a column that
# does not exist yet can raise KeyError in older pandas releases.
univariate['model'] = 'univariate'
univariate

In [None]:
# Write the GSEA rank statistic: mean bootstrap odds per protein ('variable'),
# keyed by gene name, highest odds first, with duplicate rows removed.
(
    univariate.
    reset_index().
    set_index('Gene_name').
    loc[:,['odds','variable']].
    sort_values('odds',ascending=False).
    drop_duplicates().
    to_csv(dir_+'survival_rank_statistic.csv')
)

## Prediction processing functions

In [None]:
def generate_val_scores(ppred_df, scorer=None, n_boot=50):
    """
    Bootstrap validation scores per model set.

    Parameters
    ----------
    ppred_df : pandas.DataFrame
        Patient-level predictions indexed by set, with columns ``y_true`` and
        ``y_proba`` (score for class 1).
    scorer : dict, optional
        Maps a score name to ``f(y_true, y_score)``. Defaults to
        ``roc_auc`` (roc_auc_score), ``ppv`` and ``npv`` (both
        average_precision_score). Names in ``('npv', 'specificity')`` are
        scored for the negative class: called with ``1 - y_proba`` and
        ``pos_label=0``.
    n_boot : int
        Number of bootstrap resamples per set (default 50, as before).

    Returns
    -------
    pandas.DataFrame
        Indexed by set, with ``<name>_lwr``, ``<name>_mean`` and ``<name>_upr``
        (2.5% percentile, mean, 97.5% percentile across bootstraps) per score.
    """
    # Build the default per call rather than sharing one dict as a default argument.
    if scorer is None:
        scorer = {
            'roc_auc': metrics.roc_auc_score,
            'ppv': metrics.average_precision_score,
            'npv': metrics.average_precision_score,
        }
    n_names = ['npv', 'specificity']
    score_names = list(scorer.keys())
    vals = []
    # groupby always yields a DataFrame per set. `.loc[set_]` returned a Series
    # for a set with a single row and then failed.
    for set_, tmp in ppred_df.groupby(level=0):
        for b in range(n_boot):
            sub = tmp.sample(tmp.shape[0], replace=True, random_state=b)
            arr = []
            for name in score_names:
                if name in n_names:
                    # y_proba scores class 1. With pos_label=0 the scorer needs
                    # a score for class 0, otherwise the ranking is inverted.
                    arr.append(scorer[name](sub.y_true, 1 - sub.y_proba, pos_label=0))
                else:
                    arr.append(scorer[name](sub.y_true, sub.y_proba))
            arr.extend([set_, b])
            vals.append(arr)
    cols = score_names + ['set', 'bootstrap']
    val_df = (
        pd.DataFrame(vals, columns=cols).
        groupby('set')[score_names].
        describe(percentiles=[0.025, 0.975])
    )
    val_df.columns = [
        name + '_' + stat.replace('2.5%', 'lwr').replace('97.5%', 'upr')
        for name, stat in val_df.columns
    ]
    return val_df.loc[:, np.concatenate([[x + '_lwr', x + '_mean', x + '_upr'] for x in score_names])]
def get_pperf_roc_curve_stats(dat, n=50):
    """
    Bootstrap-average an ROC curve.

    Draw ``n`` bootstrap samples (with replacement, random_state 0..n-1) of the
    rows of ``dat``, which needs ``y_true`` and ``y_proba`` columns. Compute
    each sample's ROC curve and average FPR and TPR per distinct threshold.

    Returns
    -------
    (fpr, tpr) : tuple of numpy arrays, ordered by ascending threshold.
    """
    def _one_bootstrap(seed):
        boot = dat.sample(n=dat.shape[0], replace=True, random_state=seed)
        fpr_b, tpr_b, thresholds = roc_curve(boot.y_true, boot.y_proba)
        return pd.DataFrame({'fpr': fpr_b, 'tpr': tpr_b, 't': thresholds})

    per_threshold = pd.concat([_one_bootstrap(seed) for seed in range(n)]).groupby('t').mean()
    return per_threshold['fpr'].values, per_threshold['tpr'].values
    
def get_pperf_precision_recall_curve_stats(dat, n=50):
    """
    Bootstrap-average a precision-recall curve.

    Draw ``n`` bootstrap samples (with replacement, random_state 0..n-1) of the
    rows of ``dat``, which needs ``y_true`` and ``y_proba`` columns. Compute
    each sample's precision-recall curve and average recall and precision per
    distinct threshold.

    Returns
    -------
    (recall, precision) : tuple of lists, in that order, ordered by ascending
        threshold. The first point (lowest threshold) is pinned to recall = 1
        and precision = 0.

    NOTE(review): at the lowest threshold precision normally equals the
    positive-class prevalence, not 0 — confirm the pinned anchor is intended.
    """
    def _one_bootstrap(seed):
        boot = dat.sample(n=dat.shape[0], replace=True, random_state=seed)
        precision, recall, thresholds = precision_recall_curve(boot.y_true, boot.y_proba)
        # The final (precision=1, recall=0) point has no threshold; drop it so
        # all three arrays have the same length.
        return pd.DataFrame({'recall': recall[:-1],
                             'precision': precision[:-1],
                             't': thresholds})

    per_threshold = pd.concat([_one_bootstrap(seed) for seed in range(n)]).groupby('t').mean()
    recall_mean = per_threshold['recall'].tolist()
    precision_mean = per_threshold['precision'].tolist()
    recall_mean[0] = 1
    precision_mean[0] = 0
    return recall_mean, precision_mean

def plt_atts_roc(ax, fig):
    """
    Apply ROC-plot styling and return ``fig``.

    Sets both axes to [-0.01, 1.01], draws a red dashed y = x chance line
    behind the data, uses 14 pt major tick labels and calls tight_layout.
    """
    ax.set_xlim(-0.01, 1.01)
    ax.set_ylim(-0.01, 1.01)

    # The diagonal spans the smallest to the largest limit of either axis.
    all_limits = [*ax.get_xlim(), *ax.get_ylim()]
    low, high = np.min(all_limits), np.max(all_limits)
    ax.plot([low, high], [low, high], 'r--', alpha=0.75, zorder=0)

    ax.tick_params(axis='both', which='major', labelsize=14)
    fig.tight_layout()
    return fig

def plt_atts_pr(ax, fig):
    """
    Apply precision-recall-plot styling and return ``fig``.

    Sets both axes to [-0.01, 1.01], draws a red dashed anti-diagonal from the
    top-left to the bottom-right corner behind the data, uses 14 pt major tick
    labels and calls tight_layout.

    NOTE(review): a no-skill PR baseline is usually a horizontal line at the
    positive-class prevalence rather than this anti-diagonal — confirm intent.
    """
    ax.set_xlim(-0.01, 1.01)
    ax.set_ylim(-0.01, 1.01)

    x_limits, y_limits = ax.get_xlim(), ax.get_ylim()
    line_x = [np.min(x_limits), np.max(y_limits)]
    line_y = [np.max(x_limits), np.min(y_limits)]
    ax.plot(line_x, line_y, 'r--', alpha=0.75, zorder=0)

    ax.tick_params(axis='both', which='major', labelsize=14)
    fig.tight_layout()
    return fig

## Survival predictions and Volcano plot (Figures 2 and 3)

In [None]:
# For each single-marker survival model ("set"), average over 50 bootstrap
# resamples of its patient-level predictions: AUROC, confusion-matrix counts
# and the mean predicted probability per true class, plus derived rates. Then
# attach each set's feature (protein) and gene name, sorted by AUROC.
ppred_df = pd.read_csv(
                dir_+'mortality_predictions_marker_patient_predictions_survival.csv',
                index_col=0)
arr_dfs = []
for set_,tmp in ppred_df.groupby('set'):
    vals = []
    for b in range(50):
        sub = tmp.sample(tmp.shape[0],replace=True,random_state=b)
        # Explicit labels fix the 2x2 layout (rows = true 0/1, cols = predicted
        # 0/1) instead of relying on labels inferred from the resample.
        # ravel() then yields (tn, fp, fn, tp). roc_auc_score below still
        # raises if a resample's y_true holds a single class.
        tn, fp, fn, tp = metrics.confusion_matrix(
            sub.y_true, sub.y_pred, labels=[0, 1]
        ).ravel()
        auc = metrics.roc_auc_score(sub.y_true,sub.y_proba)
        vals.append([auc,tp,tn,fp,fn,
                     np.mean(sub.y_proba[sub.y_true==1]),
                     np.mean(sub.y_proba[sub.y_true==0])])
    arr = np.array(vals)
    arr_df = \
    pd.DataFrame(
        np.mean(arr,0),
        index=['AUROC','TP','TN','FP','FN','mean_case_proba','mean_ctrl_proba'],
        columns=[set_]
    ).T
    arr_dfs.append(
        arr_df.eval(
            '''
            sensitivity = (TP/(TP+FN))
            PPV = (TP/(TP+FP))
            specificity = (TN/(TN+FP))
            NPV = (TN/(TN+FN))
            '''
        )
    )

perf_df = \
pd.read_csv(
    dir_+
    'mortality_predictions_marker_performance_survival.csv'
)
# set_features holds the string form of a list; strip brackets/quotes to get
# the bare feature name. Bracket assignment avoids attribute-style column access.
perf_df['set_features'] = [x.replace("['",'').replace("']",'').strip(", ") for x in perf_df.set_features]

tmp = (
    pd.
    concat(arr_dfs).
    join(
        perf_df.
        set_index('set').
        loc[:,['set_features']]
    ).
    reset_index().
    set_index('set_features').
    join(
        idmap_sub.
        set_index('Protein')
    ).
    sort_values('AUROC')
)
tmp.sort_values('AUROC').tail(10)

In [None]:
# Per-marker validation metrics (bootstrap 2.5%/mean/97.5% of each score from
# generate_val_scores), keyed by marker (set_features) with gene names attached.
perf_df = \
pd.read_csv(
    
    dir_+
    'mortality_predictions_marker_performance_survival.csv'
)
# set_features holds the string form of a list; strip brackets/quotes to get
# the bare feature name.
perf_df.set_features = [x.replace("['",'').replace("']",'').strip(", ") for x in perf_df.set_features]

# Patient-level predictions (y_true / y_proba per set) to bootstrap.
val_df = \
pd.read_csv(
    dir_+'mortality_predictions_marker_patient_predictions_survival.csv',
    index_col=0)

perf_df = \
(
    perf_df.
    loc[:,['set','set_features']].
    set_index('set').
    join(
        generate_val_scores(
            val_df.set_index('set')
        )
    ).
    reset_index().
    set_index('set_features').
    join(
        idmap_sub.set_index('Protein')
    ).
    sort_values('roc_auc_mean')
)

# Per-set feature importance summaries, with the same set_features cleanup.
fimp_df = pd.read_csv(dir_+'mortality_predictions_marker_feature_importance_survival.csv')
fimp_df.set_features = [x.replace("['",'').replace("']",'').strip(", ") for x in fimp_df.set_features]

In [None]:
# For each feature present in both tables, compare the observed importance
# distribution with the permuted-label importance distribution using a
# two-sample KS test. Record the permuted distribution's 2.5%/97.5% quantiles
# and Bonferroni-adjust the p-values.
full_fimp_df = \
pd.read_csv(
    dir_+'mortality_predictions_marker_full_feature_importance_survival.csv',
    index_col=0
)
full_perm_fimp_df = \
pd.read_csv(
    dir_+'mortality_predictions_marker_full_permuted_feature_importance_survival.csv',
    index_col=0
)
proteins = np.intersect1d(full_fimp_df.Feature.unique(),full_perm_fimp_df.Feature.unique())
sig_lst = []
for prot in proteins:
    x = full_fimp_df.loc[full_fimp_df.Feature == prot, 'Importance'].values
    # Filter the permuted table with its own Feature column. A mask built from
    # full_fimp_df aligns on the other table's index and selects wrong rows
    # (or raises on unalignable indexes).
    perm_importance = full_perm_fimp_df.loc[full_perm_fimp_df.Feature == prot, 'Importance']
    y = perm_importance.values
    lwr, upr = perm_importance.quantile([0.025, 0.975])
    stat, pv = ks_2samp(x,y)
    sig_lst.append(
        pd.DataFrame(
            {'statistic' : stat,'pvalue' : pv,'perm_fimp_lwr' : lwr,'perm_fimp_upr' : upr},
            index=[prot]
        )
    )
sig_df = pd.concat(sig_lst)
sig_df['padj'] = multipletests(sig_df['pvalue'],method='bonferroni')[1]
sig_df

In [None]:
# Per-marker summary: AUROC (mean and 2.5%/97.5% bootstrap bounds, renamed
# perf*), the marker's coefficient summary from fimp_df (mean, 2.5%, 97.5%,
# renamed fimp*; Intercept rows excluded) and the KS significance table, keyed
# by marker.
dat = \
(
    perf_df.
    loc[:,['roc_auc_mean','roc_auc_lwr','roc_auc_upr','set','Gene_name']].
    rename(
        columns={
            'roc_auc_mean' : 'perf',
            'roc_auc_lwr' : 'perf_lwr',
            'roc_auc_upr' : 'perf_upr'
        }
    ).
    reset_index().
    set_index('set').
    join(
        fimp_df.
        query('Feature!="Intercept"').
        set_index('set').
        loc[:,['mean','2.5%','97.5%']].
        rename(columns={'mean' : 'fimp','2.5%' : 'fimp_lwr','97.5%' : 'fimp_upr'})
    ).
    set_index('set_features').
    join(sig_df)
)
# NOTE: a p-value of exactly 0 yields inf here.
dat['-log10pvalue'] = -np.log10(dat['pvalue'])

# 'red' for rows with no missing value in any column, 'blue' otherwise.
dat['Marker'] = 'blue'
dat.loc[dat.dropna().index.values,'Marker'] = 'red'

# Drop the 'expired' and 'Survival' feature rows (presumably outcome-derived
# models — confirm).
dat = dat.query('set_features!="expired"')
dat = dat.query('set_features!="Survival"')

In [None]:
# Persist the per-marker performance / coefficient / significance table.
dat.to_csv(dir_+'mortality_survival_marker_predictions.csv')

In [None]:
display(dat.sort_values('perf').tail(10))
display(dat.sort_values('fimp').head(10))
display(dat.sort_values('fimp').tail(10))

# Volcano-style plot, excluding the PGD covariate: mean coefficient (x) vs mean
# AUROC (y), coloured by -log10 KS p-value, marker shape by the Marker flag.
dat2 = dat.query('set_features!="PGD"')

palette = 'RdBu_r'
fig, ax = plt.subplots(dpi=300)
# Data variables passed as keywords (required by seaborn >= 0.12). The legend is
# suppressed; the colorbar below carries the hue scale.
sns.scatterplot(x='fimp', y='perf', hue='-log10pvalue', data=dat2,
                style='Marker',
                style_order=dat2.sort_values('Marker').Marker.unique(),
                edgecolor='k', palette=palette, legend=False, ax=ax)
# seaborn maps a numeric hue linearly between the column's min and max. Build a
# matching ScalarMappable and attach the colorbar to this axes. (The previous
# approach drew a throwaway scatter and then cleared the figure with plt.clf().)
norm = plt.Normalize(dat2['-log10pvalue'].min(), dat2['-log10pvalue'].max())
color_scale = plt.cm.ScalarMappable(norm=norm, cmap=palette)
color_scale.set_array([])
fig.colorbar(color_scale, ax=ax)
ax.set_xlabel(r'$\beta$ coefficient',size=20)
ax.set_ylabel('AUROC',size=20)

fig.tight_layout()

fig.savefig('../../docs/imgs/mortality_mccv_predictions_survival.png')

In [None]:
# Markers passing all filters: mean AUROC > 0.5, permuted-importance interval
# containing 0, coefficient interval (fimp_lwr..fimp_upr) excluding 0, and
# Bonferroni-adjusted KS p-value < 1e-4.
(
    dat.
    query(
        'perf>0.5 & perm_fimp_lwr<=0 & perm_fimp_upr>=0 & (fimp_lwr>0 | fimp_upr<0) & padj<0.0001'
    )
)

In [None]:
# (rows, columns) of the marker table under the same filters.
(
    dat.
    query(
        'perf>0.5 & perm_fimp_lwr<=0 & perm_fimp_upr>=0 & (fimp_lwr>0 | fimp_upr<0) & padj<0.0001'
    )
).shape

In [None]:
# Save the significant markers (same filters as above), highest coefficient first.
(
    dat.
    query(
        'perf>0.5 & perm_fimp_lwr<=0 & perm_fimp_upr>=0 & (fimp_lwr>0 | fimp_upr<0) & padj<0.0001'
    ).
    loc[:,
        ['perf_lwr',
         'perf','perf_upr','fimp_lwr',
         'fimp','fimp_upr','padj','Gene_name'
        ]
       ].
    sort_values('fimp',ascending=False)
).to_csv(dir_+'mortality_significant_prediction_stats_survival.csv')
# Save every marker whose coefficient lower bound is above 0, highest first.
(
    dat.
    reset_index().
    query('fimp_lwr>0').
    loc[:,
        ['set_features','perf_lwr',
         'perf','perf_upr','fimp_lwr',
         'fimp','fimp_upr'
        ]
       ].
    set_index('set_features').
    sort_values('fimp',ascending=False)
).to_csv(dir_+'mortality_prediction_stats.csv')

In [None]:
# Patient-level predictions of every single-marker model, annotated with the
# model's feature (protein) and gene name.
ppred_df = \
(
    pd.read_csv(
        dir_+
        'mortality_predictions_marker_patient_predictions_survival.csv',
        index_col=0).
    set_index('set').
    join(
        fimp_df[['set','Feature']].
        query('Feature!="Intercept"').
        drop_duplicates().
        set_index('set')
    ).
    reset_index().
    set_index('Feature').
    join(idmap_sub.set_index('Protein'))
)
# Density of predicted probabilities for survivors vs non-survivors, using 100
# prediction rows drawn per group. random_state is fixed so the figure is
# reproducible between runs.
tmp = ppred_df[['Sample','y_proba']].set_index('Sample').join(Y)
n_draw = 100
tmp = pd.concat([
    pd.DataFrame({
        'y_proba': tmp.query('Survival==1').sample(n_draw, random_state=42).y_proba.values,
        'Survival': 'Survived',
    }),
    pd.DataFrame({
        'y_proba': tmp.query('Survival==0').sample(n_draw, random_state=42).y_proba.values,
        'Survival': 'Died',
    }),
])
tmp['y_proba'] = tmp['y_proba'].astype(float)
display(tmp)
fig,ax = plt.subplots(dpi=300)
sns.kdeplot(data=tmp,x='y_proba',hue='Survival',ax=ax)

In [None]:
# Drop the PGD row, then list the markers whose mean AUROC ('perf') exceeds 0.6
# along with their gene names.
dat = dat.query('set_features!="PGD"')
feats=dat.query('perf>0.6').index.values
print(feats)
# NOTE: .values[0] raises IndexError for a protein missing from idmap_sub.
print([idmap_sub.query('Protein==@feat')['Gene_name'].values[0] 
       for feat in feats])

In [None]:
# Per-marker figures for each selected feature: bootstrap-averaged ROC curve,
# bootstrap-averaged precision-recall curve, and a box + swarm plot of the
# min-max scaled expression split by survival status.
dpi=300
# Boolean Series named 'expired' (True = survived); the joined label column
# below keeps that name.
survived = (joined['expired']==0)
# Min-max scale every protein column once instead of once per feature.
X_scaled = X_all_proteins.apply(lambda x: (x - x.min()) / (x.max() - x.min()), axis=0)
# rc_context restores font.weight even if a plot raises part-way through.
with plt.rc_context({'font.weight': 'bold'}):
    for feat in feats:
        gene = idmap_sub.query('Protein==@feat')['Gene_name'].values[0]

        pperf = ppred_df.loc[feat]
        c = 'black'

        # ROC curve.
        fpr, tpr = get_pperf_roc_curve_stats(pperf)
        fig, ax = plt.subplots(dpi=dpi)
        ax.plot(fpr, tpr, c=c)
        ax.plot(fpr, tpr, c=c, marker='o', mec=c, ms=1, lw=0.00001)
        ax.set_xlabel('')
        fig = plt_atts_roc(ax, fig)
        fig.savefig('../../docs/imgs/'+feat+'_'+gene+'_roc_curve.png')
        plt.close(fig)

        # Precision-recall curve. The helper returns (recall, precision), so
        # recall goes on the x axis.
        recall, precision = get_pperf_precision_recall_curve_stats(pperf)
        fig, ax = plt.subplots(dpi=dpi)
        ax.plot(recall, precision, c=c)
        ax.plot(recall, precision, c=c, marker='o', mec=c, ms=1, lw=0.00001)
        fig = plt_atts_pr(ax, fig)
        fig.savefig('../../docs/imgs/'+feat+'_'+gene+'_pr_curve.png')
        plt.close(fig)

        # Box + swarm plot. A dedicated name keeps the global `dat` (the marker
        # summary table) intact.
        prot = feat
        box_df = X_scaled[[prot]].join(
            survived.map({True : 'Survived',False : 'Died'})
        ).join(joined.PGD.map({1 : 'PGD',0 : 'nonPGD'}))
        fig, ax = plt.subplots(figsize=(5,3), dpi=300)
        sns.boxplot(x='expired', y=prot, data=box_df, color='lightgray', ax=ax)
        sns.swarmplot(
            x='expired', y=prot,
            data=box_df,
            color="black",
            edgecolor='black', ax=ax
        )
        ax.set_xlabel('')
        ax.set_ylabel('')
        ax.set_yticklabels(ax.get_yticks(), fontsize=12)
        ax.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.1f'))
        ax.set_xticklabels(ax.get_xticklabels(), fontsize=12)

        fig.savefig('../../docs/imgs/'+prot+'_'+gene+'_boxplot.png')
        plt.close(fig)


## Survival prediction comparison with covariates (Figure S3)

In [None]:
# Keep the survival-only (no covariate) performance and importance tables under
# explicit names for the with/without-covariate comparison below.
s_perf_df = perf_df
s_fimp_df = fimp_df

In [None]:
# Validation metrics for the single-marker models fitted WITH the cohort
# covariates, processed the same way as the survival-only models.
swcovs_perf_df = \
pd.read_csv(
    
    dir_+
    'mortality_predictions_marker_performance_survival_wcovs.csv'
)
# set_features holds the string form of a list; strip brackets/quotes.
swcovs_perf_df.set_features = [x.replace("['",'').replace("']",'').strip(", ") for x in swcovs_perf_df.set_features]

idmap_sub = pd.read_csv('../../data/protein_gene_map_full.csv')[['Protein','Gene_name']].dropna()

swcovs_perf_df = \
(
    swcovs_perf_df.
    loc[:,['set','set_features']].
    set_index('set').
    join(
        generate_val_scores(
            pd.read_csv(
                dir_+'mortality_predictions_marker_patient_predictions_survival_wcovs.csv',
                index_col=0).set_index('set')
        )
    ).
    reset_index().
    set_index('set_features').
    join(
        idmap_sub.set_index('Protein')
    ).
    sort_values('roc_auc_mean')
)

# Feature importance summaries for the with-covariate models.
swcovs_fimp_df = pd.read_csv(dir_+'mortality_predictions_marker_feature_importance_survival_wcovs.csv')
swcovs_fimp_df.set_features = [x.replace("['",'').replace("']",'').strip(", ") for x in swcovs_fimp_df.set_features]

In [None]:
# Figure S3 (AUROC): per-marker mean AUROC with the cohort (site-of-origin)
# covariates (x, 'wcovs') against the marker-only models (y, 'wocovs'), joined
# on the marker, with a red y = x reference line.
# NOTE(review): sets 40, 42 and 35 are excluded by hard-coded ID — confirm why.
tmp = \
(
    swcovs_perf_df.
    query('set!=40 & set!=42 & set!=35').
    loc[:,['roc_auc_mean']].
    rename(columns = {'roc_auc_mean' : 'wcovs'}).
    join(
        s_perf_df.
        loc[:,['roc_auc_mean']].
        rename(columns = {'roc_auc_mean' : 'wocovs'})
    )
)
fig,ax=plt.subplots(dpi=300)
# Data variables passed as keywords (required by seaborn >= 0.12).
g = sns.scatterplot(x='wcovs',y='wocovs',data=tmp,ax=ax)
ax.set_xticklabels(ax.get_xticklabels(),fontsize=12,weight='bold')
ax.set_xlabel('Prediction with site-of-origin covariates',fontsize=12,weight='bold')
ax.xaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.2f'))
ax.set_ylabel('Marker prediction',fontsize=12,weight='bold')
ax.set_yticklabels(ax.get_yticks(),fontsize=12,weight='bold')
ax.yaxis.set_major_formatter(matplotlib.ticker.FormatStrFormatter('%.2f'))

# Identity line over the overlap of the two axis ranges.
x0, x1 = g.axes.get_xlim()
y0, y1 = g.axes.get_ylim()
lims = [max(x0, y0), min(x1, y1)]
g.axes.plot(lims, lims, 'r--')
fig.tight_layout()
fig.savefig('../../docs/imgs/survival_marker_predictions_w_wo_covariates.png')

In [None]:
# Figure S3 (coefficients): per-marker mean coefficient from the models fitted
# with the cohort covariates (x, 'wcovs') against the marker-only models
# (y, 'wocovs'), with a red y = x reference line.
# s_fimp_df holds the survival-only (no covariate) results and swcovs_fimp_df
# the with-covariate results, so they map to 'wocovs' and 'wcovs'
# respectively. The two were swapped, contradicting the axis labels and the
# AUROC comparison plot.
# NOTE(review): the tables are joined on 'set' ID, which assumes both runs
# number their sets identically — confirm (joining on Feature would not).
tmp = \
(
    s_fimp_df.
    query('Feature!="Intercept" & Feature!="Cohort_Columbia" & Feature!="Cohort_Cedar" & Feature!="Cohort_Paris"').
    set_index('set').
    loc[:,['mean']].
    rename(columns={'mean' : 'wocovs'}).
    join(
        swcovs_fimp_df.
        query('Feature!="expired" & Feature!="Survival" & Feature!="PGD"').
        query('Feature!="Intercept" & Feature!="Cohort_Columbia" & Feature!="Cohort_Cedar" & Feature!="Cohort_Paris"').
        set_index('set').
        loc[:,['mean']].
        rename(columns={'mean' : 'wcovs'})
    )
)
fig,ax=plt.subplots(dpi=300)
# Data variables passed as keywords (required by seaborn >= 0.12).
g = sns.scatterplot(x='wcovs',y='wocovs',data=tmp,ax=ax)
ax.set_xlabel('Marker association with site-of-origin covariates',fontsize=12,weight='bold')
ax.set_xticklabels(ax.get_xticks(),fontsize=12,weight='bold')
ax.set_ylabel('Marker association',fontsize=12,weight='bold')
ax.set_yticklabels(ax.get_yticks(),fontsize=12,weight='bold')

# Identity line over the overlap of the two axis ranges.
x0, x1 = g.axes.get_xlim()
y0, y1 = g.axes.get_ylim()
lims = [max(x0, y0), min(x1, y1)]
g.axes.plot(lims, lims, 'r--')
fig.tight_layout()
fig.savefig('../../docs/imgs/survival_marker_associations_w_wo_covariates.png')

In [None]:
# For each marker set, bootstrap (50 resamples) the per-patient validation
# predictions of the covariate-adjusted survival models and average AUROC,
# confusion-matrix counts and mean predicted probability in cases/controls.
# Sensitivity, PPV, specificity and NPV are derived from the mean counts.
ppred_df = pd.read_csv(
                dir_+'mortality_predictions_marker_patient_predictions_survival_wcovs.csv',
                index_col=0)
arr_dfs = []
for set_, tmp in ppred_df.groupby('set'):
    vals = []
    for b in range(50):
        sub = tmp.sample(tmp.shape[0], replace=True, random_state=b)
        # A resample can contain a single outcome class: roc_auc_score then
        # raises and the confusion matrix collapses to 1x1, so skip the draw.
        if sub.y_true.nunique() < 2:
            continue
        # Pin the class order so ravel() yields TN, FP, FN, TP.
        tn, fp, fn, tp = metrics.confusion_matrix(
            sub.y_true, sub.y_pred, labels=[0, 1]).ravel()
        auc = metrics.roc_auc_score(sub.y_true, sub.y_proba)
        vals.append([auc, tp, tn, fp, fn,
                     np.mean(sub.y_proba[sub.y_true == 1]),
                     np.mean(sub.y_proba[sub.y_true == 0])])
    if not vals:
        # No resample contained both classes; nothing to summarize for this set.
        continue
    arr = np.array(vals)
    arr_df = \
    pd.DataFrame(
        np.mean(arr, 0),
        index=['AUROC','TP','TN','FP','FN','mean_case_proba','mean_ctrl_proba'],
        columns=[set_]
    ).T
    arr_dfs.append(
        arr_df.eval(
            '''
            sensitivity = (TP/(TP+FN))
            PPV = (TP/(TP+FP))
            specificity = (TN/(TN+FP))
            NPV = (TN/(TN+FN))
            '''
        )
    )

perf_df = pd.read_csv(dir_ + 'mortality_predictions_marker_performance_survival_wcovs.csv')
perf_df.set_features = [x.replace("['",'').replace("']",'').strip(", ") for x in perf_df.set_features]

# Attach each set's marker identifier and gene name, ordered by AUROC.
tmp = (
    pd.concat(arr_dfs)
    .join(perf_df.set_index('set').loc[:, ['set_features']])
    .reset_index()
    .set_index('set_features')
    .join(idmap_sub.set_index('Protein'))
    .sort_values('AUROC')
)
tmp.sort_values('AUROC').tail(10)

In [None]:
# Reload the covariate-adjusted (wcovs) survival results: performance table
# annotated with generate_val_scores output and gene names, plus importances.
perf_df = pd.read_csv(dir_ + 'mortality_predictions_marker_performance_survival_wcovs.csv')
perf_df.set_features = [s.replace("['", '').replace("']", '').strip(", ") for s in perf_df.set_features]

val_df = pd.read_csv(dir_ + 'mortality_predictions_marker_patient_predictions_survival_wcovs.csv',
                     index_col=0)

perf_df = (
    perf_df[['set', 'set_features']]
    .set_index('set')
    .join(generate_val_scores(val_df.set_index('set')))
    .reset_index()
    .set_index('set_features')
    .join(idmap_sub.set_index('Protein'))
    .sort_values('roc_auc_mean')
)

fimp_df = pd.read_csv(dir_ + 'mortality_predictions_marker_feature_importance_survival_wcovs.csv')
fimp_df.set_features = [s.replace("['", '').replace("']", '').strip(", ") for s in fimp_df.set_features]

In [None]:
# Permutation significance of each marker's importance (covariate-adjusted
# survival models): two-sample KS test of observed vs. permuted importances,
# plus the 2.5%-97.5% range of the permuted importances.
full_fimp_df = pd.read_csv(
    dir_ + 'mortality_predictions_marker_full_feature_importance_survival_wcovs.csv',
    index_col=0
)
full_perm_fimp_df = pd.read_csv(
    dir_ + 'mortality_predictions_marker_full_permuted_feature_importance_survival_wcovs.csv',
    index_col=0
)
# Only features present in both tables can be compared.
proteins = np.intersect1d(full_fimp_df.Feature.unique(), full_perm_fimp_df.Feature.unique())
sig_lst = []
for prot in proteins:
    x = full_fimp_df[full_fimp_df.Feature == prot].Importance.values
    # Mask the permuted table with its OWN Feature column. A mask built from
    # full_fimp_df is aligned on index labels, so it selects the wrong rows
    # (or raises "Unalignable boolean Series") unless both files share an
    # identical index and row order.
    perm_imp = full_perm_fimp_df[full_perm_fimp_df.Feature == prot].Importance
    y = perm_imp.values
    lwr, upr = perm_imp.describe(percentiles=[0.025, 0.975]).loc[['2.5%', '97.5%']]
    stat, pv = ks_2samp(x, y)
    sig_lst.append(
        pd.DataFrame(
            {'statistic': stat, 'pvalue': pv, 'perm_fimp_lwr': lwr, 'perm_fimp_upr': upr},
            index=[prot]
        )
    )
sig_df = pd.concat(sig_lst)
# Bonferroni correction across all tested markers.
sig_df['padj'] = multipletests(sig_df['pvalue'], method='bonferroni')[1]
sig_df

In [None]:
# Per-marker summary (covariate-adjusted models): validation AUROC with
# interval, gene name, marker importance with interval (intercept and cohort
# covariate rows removed) and the permutation-test statistics in sig_df.
dat = (
    perf_df[['roc_auc_mean', 'roc_auc_lwr', 'roc_auc_upr', 'set', 'Gene_name']]
    .rename(columns={
        'roc_auc_mean': 'perf',
        'roc_auc_lwr': 'perf_lwr',
        'roc_auc_upr': 'perf_upr',
    })
    .reset_index()
    .set_index('set')
    .join(
        fimp_df
        .query('Feature!="Intercept"')
        .query("Feature!='Cohort_Columbia'")
        .query("Feature!='Cohort_Cedar'")
        .query("Feature!='Cohort_Paris'")
        .set_index('set')[['mean', '2.5%', '97.5%']]
        .rename(columns={'mean': 'fimp', '2.5%': 'fimp_lwr', '97.5%': 'fimp_upr'})
    )
    .set_index('set_features')
    .join(sig_df)
)
dat['-log10pvalue'] = -np.log10(dat['pvalue'])
dat.index.name = 'set_features'
dat = dat.reset_index()
dat

In [None]:
# Persist the covariate-adjusted per-marker summary table.
dat.to_csv(path_or_buf=dir_ + 'mortality_survival_marker_predictions_wcovs.csv')

In [None]:
# Markers passing the significance screen with covariates, looked up in the
# covariate-free summary and shown by decreasing importance there.
dat_wcovs = dat.copy()
dat_wocovs = pd.read_csv(dir_ + 'mortality_survival_marker_predictions.csv')
# Screen: AUROC > 0.5, permuted-importance interval covering 0,
# observed-importance interval excluding 0, padj < 1e-4.
fts = dat_wcovs.query(
    'perf>0.5 & perm_fimp_lwr<=0 & perm_fimp_upr>=0 & (fimp_lwr>0 | fimp_upr<0) & padj<0.0001'
)['set_features'].values
dat_wocovs.query(
    'perf>0.5 & perm_fimp_lwr<=0 & perm_fimp_upr>=0 & (fimp_lwr>0 | fimp_upr<0) & padj<0.0001'
).query('set_features in @fts').sort_values('fimp', ascending=False)

## Survival prediction within 1 year

In [None]:
# Within-1-year survival models: performance table annotated with
# generate_val_scores output and gene names, plus feature importances.
perf_df = pd.read_csv(dir_ + 'mortality_predictions_marker_performance_survival_wn_year.csv')
perf_df.set_features = [s.replace("['", '').replace("']", '').strip(", ") for s in perf_df.set_features]
perf_df = (
    perf_df[['set', 'set_features']]
    .set_index('set')
    .join(generate_val_scores(
        pd.read_csv(dir_ + 'mortality_predictions_marker_patient_predictions_survival_wn_year.csv',
                    index_col=0)
        .set_index('set')
    ))
    .reset_index()
    .set_index('set_features')
    .join(idmap_sub.set_index('Protein'))
    .sort_values('roc_auc_mean')
)

fimp_df = pd.read_csv(dir_ + 'mortality_predictions_marker_feature_importance_survival_wn_year.csv')
fimp_df.set_features = [s.replace("['", '').replace("']", '').strip(", ") for s in fimp_df.set_features]

In [None]:
# Permutation significance of each marker's importance (within-1-year
# survival models): KS test of observed vs. permuted importances, plus the
# 2.5%-97.5% range of the permuted importances.
full_fimp_df = pd.read_csv(
    dir_ + 'mortality_predictions_marker_full_feature_importance_survival_wn_year.csv',
    index_col=0
)
full_perm_fimp_df = pd.read_csv(
    dir_ + 'mortality_predictions_marker_full_permuted_feature_importance_survival_wn_year.csv',
    index_col=0
)
# Only features present in both tables can be compared.
proteins = np.intersect1d(full_fimp_df.Feature.unique(), full_perm_fimp_df.Feature.unique())
sig_lst = []
for prot in proteins:
    x = full_fimp_df[full_fimp_df.Feature == prot].Importance.values
    # Mask the permuted table with its OWN Feature column; a mask from
    # full_fimp_df is label-aligned and picks the wrong rows (or raises)
    # unless both files share an identical index and row order.
    perm_imp = full_perm_fimp_df[full_perm_fimp_df.Feature == prot].Importance
    y = perm_imp.values
    lwr, upr = perm_imp.describe(percentiles=[0.025, 0.975]).loc[['2.5%', '97.5%']]
    stat, pv = ks_2samp(x, y)
    sig_lst.append(
        pd.DataFrame(
            {'statistic': stat, 'pvalue': pv, 'perm_fimp_lwr': lwr, 'perm_fimp_upr': upr},
            index=[prot]
        )
    )
sig_df = pd.concat(sig_lst)
# Bonferroni correction across all tested markers.
sig_df['padj'] = multipletests(sig_df['pvalue'], method='bonferroni')[1]
sig_df

In [None]:
# Per-marker summary (within-1-year models): validation AUROC with interval,
# gene name, marker importance with interval (intercept and PGD rows removed)
# and permutation-test statistics; PGD/expired/Survival rows are dropped.
dat = (
    perf_df[['roc_auc_mean', 'roc_auc_lwr', 'roc_auc_upr', 'set', 'Gene_name']]
    .rename(columns={
        'roc_auc_mean': 'perf',
        'roc_auc_lwr': 'perf_lwr',
        'roc_auc_upr': 'perf_upr',
    })
    .reset_index()
    .set_index('set')
    .join(
        fimp_df
        .query('Feature!="Intercept" & Feature!="PGD"')
        .set_index('set')[['mean', '2.5%', '97.5%']]
        .rename(columns={'mean': 'fimp', '2.5%': 'fimp_lwr', '97.5%': 'fimp_upr'})
    )
    .set_index('set_features')
    .join(sig_df)
)
dat = (
    dat
    .query('set_features!="PGD"')
    .query('set_features!="expired"')
    .query('set_features!="Survival"')
)

In [None]:
# Markers passing the screen: AUROC > 0.5, permuted-importance interval
# covering 0, observed-importance interval excluding 0, padj < 1e-4.
dat.query(
    'perf>0.5 & perm_fimp_lwr<=0 & perm_fimp_upr>=0 & (fimp_lwr>0 | fimp_upr<0) & padj<0.0001'
)

In [None]:
# Export the significant within-1-year markers (same screen as above) with
# performance/importance intervals, sorted by decreasing importance.
(
    dat.query(
        'perf>0.5 & perm_fimp_lwr<=0 & perm_fimp_upr>=0 & (fimp_lwr>0 | fimp_upr<0) & padj<0.0001'
    )[['perf_lwr', 'perf', 'perf_upr', 'fimp_lwr', 'fimp', 'fimp_upr', 'padj', 'Gene_name']]
    .sort_values('fimp', ascending=False)
    .to_csv(dir_ + 'mortality_significant_prediction_stats_survival_wn_year.csv')
)

## Survival prediction with PGD covariate

In [None]:
# Survival models with the PGD covariate: performance table annotated with
# generate_val_scores output and gene names, plus feature importances.
perf_df = pd.read_csv(dir_ + 'mortality_predictions_marker_performance_survival_wpgdcov.csv')
perf_df.set_features = [s.replace("['", '').replace("']", '').strip(", ") for s in perf_df.set_features]
perf_df = (
    perf_df[['set', 'set_features']]
    .set_index('set')
    .join(generate_val_scores(
        pd.read_csv(dir_ + 'mortality_predictions_marker_patient_predictions_survival_wpgdcov.csv',
                    index_col=0)
        .set_index('set')
    ))
    .reset_index()
    .set_index('set_features')
    .join(idmap_sub.set_index('Protein'))
    .sort_values('roc_auc_mean')
)

fimp_df = pd.read_csv(dir_ + 'mortality_predictions_marker_feature_importance_survival_wpgdcov.csv')
fimp_df.set_features = [s.replace("['", '').replace("']", '').strip(", ") for s in fimp_df.set_features]

In [None]:
# Permutation significance of each marker's importance (survival models with
# the PGD covariate): KS test of observed vs. permuted importances, plus the
# 2.5%-97.5% range of the permuted importances.
full_fimp_df = pd.read_csv(
    dir_ + 'mortality_predictions_marker_full_feature_importance_survival_wpgdcov.csv',
    index_col=0
)
full_perm_fimp_df = pd.read_csv(
    dir_ + 'mortality_predictions_marker_full_permuted_feature_importance_survival_wpgdcov.csv',
    index_col=0
)
# Only features present in both tables can be compared.
proteins = np.intersect1d(full_fimp_df.Feature.unique(), full_perm_fimp_df.Feature.unique())
sig_lst = []
for prot in proteins:
    x = full_fimp_df[full_fimp_df.Feature == prot].Importance.values
    # Mask the permuted table with its OWN Feature column; a mask from
    # full_fimp_df is label-aligned and picks the wrong rows (or raises)
    # unless both files share an identical index and row order.
    perm_imp = full_perm_fimp_df[full_perm_fimp_df.Feature == prot].Importance
    y = perm_imp.values
    lwr, upr = perm_imp.describe(percentiles=[0.025, 0.975]).loc[['2.5%', '97.5%']]
    stat, pv = ks_2samp(x, y)
    sig_lst.append(
        pd.DataFrame(
            {'statistic': stat, 'pvalue': pv, 'perm_fimp_lwr': lwr, 'perm_fimp_upr': upr},
            index=[prot]
        )
    )
sig_df = pd.concat(sig_lst)
# Bonferroni correction across all tested markers.
sig_df['padj'] = multipletests(sig_df['pvalue'], method='bonferroni')[1]
sig_df

In [None]:
# Per-marker summary (PGD-covariate models): validation AUROC with interval,
# gene name, marker importance with interval (intercept and PGD rows removed)
# and permutation-test statistics; PGD/expired/Survival rows are dropped.
dat = (
    perf_df[['roc_auc_mean', 'roc_auc_lwr', 'roc_auc_upr', 'set', 'Gene_name']]
    .rename(columns={
        'roc_auc_mean': 'perf',
        'roc_auc_lwr': 'perf_lwr',
        'roc_auc_upr': 'perf_upr',
    })
    .reset_index()
    .set_index('set')
    .join(
        fimp_df
        .query('Feature!="Intercept" & Feature!="PGD"')
        .set_index('set')[['mean', '2.5%', '97.5%']]
        .rename(columns={'mean': 'fimp', '2.5%': 'fimp_lwr', '97.5%': 'fimp_upr'})
    )
    .set_index('set_features')
    .join(sig_df)
)
dat = (
    dat
    .query('set_features!="PGD"')
    .query('set_features!="expired"')
    .query('set_features!="Survival"')
)

In [None]:
# Markers passing the screen: AUROC > 0.5, permuted-importance interval
# covering 0, observed-importance interval excluding 0, padj < 1e-4.
dat.query(
    'perf>0.5 & perm_fimp_lwr<=0 & perm_fimp_upr>=0 & (fimp_lwr>0 | fimp_upr<0) & padj<0.0001'
)

In [None]:
# Gene names listed in the saved survival significance table.
prots = pd.read_csv(dir_ + 'mortality_significant_prediction_stats_survival.csv')['Gene_name'].values

In [None]:
# Export markers that pass the significance screen in the PGD-covariate
# models, plus any marker whose gene is in `prots`, sorted by importance.
(
    dat.query(
        '(perf>0.5 & perm_fimp_lwr<=0 & perm_fimp_upr>=0 & \
        (fimp_lwr>0 | fimp_upr<0) & padj<0.0001) | \
        Gene_name in @prots'
    )[['perf_lwr', 'perf', 'perf_upr', 'fimp_lwr', 'fimp', 'fimp_upr', 'padj', 'Gene_name']]
    .sort_values('fimp', ascending=False)
    .to_csv(dir_ + 'mortality_significant_prediction_stats_survival_wpgdcov.csv')
)

## PGD predictions (Figure S4)

In [None]:
# Survival models (no extra covariates in the file name): performance table
# annotated with generate_val_scores output and gene names.
perf_df = pd.read_csv(dir_ + 'mortality_predictions_marker_performance_survival.csv')
perf_df.set_features = [s.replace("['", '').replace("']", '').strip(", ") for s in perf_df.set_features]
perf_df = (
    perf_df[['set', 'set_features']]
    .set_index('set')
    .join(generate_val_scores(
        pd.read_csv(dir_ + 'mortality_predictions_marker_patient_predictions_survival.csv',
                    index_col=0)
        .set_index('set')
    ))
    .reset_index()
    .set_index('set_features')
    .join(idmap_sub.set_index('Protein'))
    .sort_values('roc_auc_mean')
)
perf_df

In [None]:
# PGD prediction results: performance table annotated with
# generate_val_scores output and gene names (saved to disk), plus importances.
pgd_perf_df = pd.read_csv(dir_ + 'mortality_predictions_marker_performance_survival_pgd.csv')
pgd_perf_df.set_features = [s.replace("['", '').replace("']", '').strip(", ") for s in pgd_perf_df.set_features]
pgd_perf_df = (
    pgd_perf_df[['set', 'set_features']]
    .set_index('set')
    .join(generate_val_scores(
        pd.read_csv(dir_ + 'mortality_predictions_marker_patient_predictions_survival_pgd.csv',
                    index_col=0)
        .set_index('set')
    ))
    .reset_index()
    .set_index('set_features')
    .join(idmap_sub.set_index('Protein'))
    .sort_values('roc_auc_mean')
)
pgd_perf_df.to_csv(dir_ + 'mortality_predictions_marker_processed_patient_predictions_pgd.csv')

pgd_fimp_df = pd.read_csv(dir_ + 'mortality_predictions_marker_feature_importance_survival_pgd.csv')
pgd_fimp_df.set_features = [s.replace("['", '').replace("']", '').strip(", ") for s in pgd_fimp_df.set_features]

In [None]:
# Per-marker PGD summary: validation AUROC with interval, gene name and
# marker importance with interval (intercept removed); PGD/expired/Survival
# rows are dropped before saving.
dat = (
    pgd_perf_df[['roc_auc_mean', 'roc_auc_lwr', 'roc_auc_upr', 'set', 'Gene_name']]
    .rename(columns={
        'roc_auc_mean': 'perf',
        'roc_auc_lwr': 'perf_lwr',
        'roc_auc_upr': 'perf_upr',
    })
    .reset_index()
    .set_index('set')
    .join(
        pgd_fimp_df
        .query('Feature!="Intercept"')
        .set_index('set')[['mean', '2.5%', '97.5%']]
        .rename(columns={'mean': 'fimp', '2.5%': 'fimp_lwr', '97.5%': 'fimp_upr'})
    )
    .set_index('set_features')
)
dat = (
    dat
    .query('set_features!="PGD"')
    .query('set_features!="expired"')
    .query('set_features!="Survival"')
)
dat.to_csv(dir_ + 'mortality_pgd_marker_predictions.csv')



In [None]:
# Compare each marker's mean validation AUROC for survival ('mean') and PGD
# ('pgd_mean') prediction; outcome rows PGD/expired/Survival are removed.
dat = (
    perf_df
    .reset_index()
    .set_index('set_features')
    .rename(columns={'roc_auc_mean': 'mean'})
    .loc[:, ['mean']]
    .query('set_features!="PGD" & set_features!="expired"')
    .join(
        pgd_perf_df
        .reset_index()
        .set_index('set_features')
        .rename(columns={'roc_auc_mean': 'pgd_mean'})
        .loc[:, ['pgd_mean']]
    )
    .query('set_features!="Survival"')
)


fig, ax = plt.subplots(dpi=300)
display(dat.sort_values('mean'))
display(dat.sort_values('pgd_mean'))
display(dat.query('mean>0.5 & pgd_mean>0.5').sort_values('pgd_mean'))
# x/y must be keywords: seaborn >= 0.12 made them keyword-only.
ax = sns.scatterplot(x='mean', y='pgd_mean', data=dat, ax=ax)
ax.set_xlabel('Survival prediction')
ax.set_ylabel('PGD prediction')

fig.tight_layout()
fig.savefig('../../docs/imgs/survival_vs_pgd_marker_prediction.png')

In [None]:
# Spearman rank correlation between each marker's survival AUROC ('mean') and
# PGD AUROC ('pgd_mean'). `dat` is presumably built by a left join, so a
# marker without a PGD result carries NaN, which would make spearmanr return
# NaN; keep only complete pairs.
paired = dat[['mean', 'pgd_mean']].dropna()
a = paired['mean'].values
b = paired['pgd_mean'].values
spearmanr(a, b)

## GSEA (Table S2 and Figure S5)

In [None]:
import gseapy as gp
# Show the installed gseapy version (output formats used below may depend on it).
gp.__version__

In [None]:
# Enrichr gene-set libraries to test, in analysis order.
gs = [
    'GO_Biological_Process_2017b',
    'GO_Molecular_Function_2017b',
    'GO_Cellular_Component_2017b',
    'Reactome_2016',
    'WikiPathways_2019_Human',
    'KEGG_2019_Human',
]
# gseapy report column -> human-readable column name for exported tables.
col_map = {
    'nes': 'Normalized Enrichment Score',
    'pval': 'P-value',
    'fdr': 'False Discovery Rate',
    "Category": 'Category',
}

In [None]:
# Reload the de-duplicated mortality table and print its outcome as a
# space-separated ALIVE (expired == 0) / DIED row, used to build the GSEA
# class file (mortality_dedup.cls).
joined = pd.read_csv('../../data/mortality_X_y_dedup.csv', index_col=0)
tmp = joined.expired.values
tmp2 = [('ALIVE' if v == 0 else 'DIED') for v in tmp]
print(' '.join(map(str, tmp2)))

In [None]:
# Parse the two phenotype labels and the per-sample class vector from the
# GSEA class file.
(phenoA,
 phenoB,
 class_vector) = gp.parser.gsea_cls_parser('../../data/mortality_dedup.cls')
print(phenoA)
print(phenoB)
class_vector

In [None]:
# Build the expression matrix for gseapy: relabel protein columns by gene
# name, name the sample index 'NAME', then transpose so rows are genes.
idmap_sub = pd.read_csv('../../data/protein_gene_map_full.csv').set_index('Protein')
genes = [idmap_sub.loc[prot_id].Gene_name for prot_id in X_all_proteins.columns]
dat = X_all_proteins.copy()
dat.columns = genes
dat.index.name = 'NAME'
dat = dat.T
dat

### Phenotype permutation

In [None]:
# Phenotype-permutation GSEA of `dat` against the classes in
# mortality_dedup.cls, once per gene-set library; results are kept in memory
# and keyed by library name.
reses = {}
for g in gs:
    gs_res = gp.gsea(
        data=dat,
        gene_sets=g,                   # Enrichr library name
        cls='../../data/mortality_dedup.cls',
        permutation_type='phenotype',  # phenotype permutation is advised when samples >= 15
        permutation_num=1000,
        outdir=None,                   # nothing written to disk
        no_plot=True,
        method='signal_to_noise',
        processes=4,
        seed=7,
        format='png',
    )
    # Tag rows with their library so the tables can be stacked later.
    gs_res.res2d['geneset'] = g
    reses[g] = gs_res

In [None]:
# Stack every library's GSEA result table and order by FDR, lowest first.
res2ds = pd.concat([res.res2d for res in reses.values()]).sort_values('fdr')

In [None]:
# Save the combined phenotype-permutation GSEA results.
res2ds.to_csv(path_or_buf='../../data/mortality_gsea_phenotype_permutation_results.csv')

In [None]:
# Term names (res2ds index) and their source library, aligned row by row.
terms, genesets = res2ds.index.values, res2ds.geneset

In [None]:
# Library name -> gene ranking from that library's GSEA run.
ranking_dict = {g: reses[g].ranking for g in gs}

In [None]:
from gseapy.plot import gseaplot, heatmap

In [None]:
# One GSEA enrichment plot per term in res2ds. zip pairs each term with its
# library by position; `genesets[i]` on this term-indexed Series relied on
# pandas' deprecated integer-as-position fallback. '/' in a term is replaced
# with '-' in the output file name.
for term, geneset in zip(terms, genesets):
    ranking = ranking_dict[geneset]
    gseaplot(ranking, term=term, **reses[geneset].results[term],
             ofname='../../docs/imgs/all_fdr/gsea_plots/mortality_gsea'+term.replace('/','-')+'.png')

In [None]:
# One leading-edge heatmap per term in res2ds. Terms and libraries are paired
# positionally with zip rather than by integer-indexing the term-indexed
# `genesets` Series (deprecated positional fallback).
for term, geneset in zip(terms, genesets):
    gs_res = reses[geneset]
    # ';'-separated leading-edge genes for this term.
    genes = gs_res.res2d.loc[term].ledge_genes.split(';')
    mat = gs_res.heatmat.copy()
    # Suffix each sample column with its class label from the .cls file.
    mat.columns = [col + ' ' + label for col, label in zip(gs_res.heatmat.columns, class_vector)]
    heatmap(df=mat.loc[genes], z_score=0,
            title=term, figsize=(24, 6),
            ofname='../../docs/imgs/all_fdr/heatmaps/mortality_gsea_heatmap_'+term.replace('/','-')+'.png')

### Gene permutation

In [None]:
# Per-gene ranking statistic ('odds') for pre-ranked GSEA; keep the first
# row for each gene name.
rnk = (
    pd.read_csv(dir_ + 'survival_rank_statistic.csv')
    .loc[:, ['Gene_name', 'odds']]
    .drop_duplicates(subset='Gene_name')
)
rnk.head()

In [None]:
# Pre-ranked GSEA on the per-gene 'odds' statistic, one run per library; each
# run's report is then read back from dir_/survival_gsea/<library>/.
for g in gs:
    print('\t' + g)
    pre_res = gp.prerank(
        rnk=rnk[['Gene_name', 'odds']],
        gene_sets=g,
        processes=4,
        permutation_num=10000,
        outdir=dir_ + 'survival_gsea/' + g,
        format='png',
    )

# Keep terms with FDR < 0.2 and non-zero NES (plus the fdr>pval | fdr==0
# condition), relabel columns via col_map and tag each row with its library.
datas = []
for g in gs:
    data = (
        pd.read_csv(dir_ + 'survival_gsea/' + g + '/gseapy.prerank.gene_sets.report.csv')
        .sort_values(['fdr', 'nes'], ascending=[True, False])
        .query('fdr < 0.2 & (nes > 0 | nes < 0) & (fdr>pval | fdr==0)')
        .rename(columns=col_map)
        .set_index('Term')
    )
    data['Category'] = g
    datas.append(data)
pd.concat(datas).to_csv(dir_ + 'survival_gsea/' + 'prerank_report_all_categories.csv')
pd.concat(datas).shape

In [None]:
# Combine the filtered pre-ranked GSEA terms from all libraries. Save the
# full table and a rounded copy restricted to the col_map display columns,
# ordered by FDR.
enriched = pd.concat(datas).copy()
enriched.to_csv(dir_ + 'survival_gsea/' + 'pathways_functions_wgenes.csv')
tmp = (
    enriched[list(col_map.values())]
    .sort_values('False Discovery Rate', ascending=True)
    .round(4)
)
tmp.sort_values('False Discovery Rate', ascending=True).to_csv(dir_ + 'survival_gsea/' + 'pathways_functions.csv')
display(tmp.sort_values('False Discovery Rate', ascending=True))
print(tmp.shape[0])