# Load Data

In [1]:
# Base / Native
import math
import os
from os.path import join
import pickle
import re
import warnings
warnings.filterwarnings('ignore')

# Numerical / Array
import lifelines
from lifelines.statistics import logrank_test
from lifelines.utils import concordance_index as ci
from sksurv.metrics import cumulative_dynamic_auc
import numpy as np
import pandas as pd
import scipy
from scipy import interp
from scipy.stats import ttest_ind
from tqdm import tqdm

In [2]:
def getResultsFromPKL(dataroot: str, study_dir: str, split: str, idx: int, which_summary:str=''):
    results_df = pd.DataFrame(pd.read_pickle(join(dataroot, study_dir, 'split%s_%s_%d_results.pkl' % 
                                                  (which_summary, split, idx)))).T
    results_df = results_df.drop(['slide_id'], axis=1).astype(float)
    return results_df



def boot_cindex(results_df):
    bootci = []
    for seed in range(1000):
        bootstrap = results_df.sample(n=results_df.shape[0], replace=True, random_state=seed)
        bootci.append(ci(event_times=bootstrap['survival'], 
                         predicted_scores=-1*bootstrap['risk'],
                         event_observed=1-bootstrap['censorship']))
    return np.array(bootci)


def CI_pm(data, confidence=0.95):
    alpha = 0.95
    p = ((1.0-alpha)/2.0) * 100
    lower = max(0.0, np.percentile(data, p))
    p = (alpha+((1.0-alpha)/2.0)) * 100
    upper = min(1.0, np.percentile(data, p))
    return '(%0.3f-%0.3f)' % (lower, upper)



def hazard2grade(hazard, p):
    for i in range(len(p)):
        if hazard < p[i]:
            return i
    return len(p)


def getPValue_Binary(results_df: pd.DataFrame=None, risk_percentiles=[50]):
    p = np.percentile(results_df['risk'], risk_percentiles)
    results_df.insert(0, 'strat', [hazard2grade(risk, p) for risk in results_df['risk']])
    T_low, T_high = results_df['survival'][results_df['strat']==0], results_df['survival'][results_df['strat']==1]
    E_low, E_high = 1-results_df['censorship'][results_df['strat']==0], 1-results_df['censorship'][results_df['strat']==1]

    low_vs_high = logrank_test(durations_A=T_low, durations_B=T_high, event_observed_A=E_low, event_observed_B=E_high).p_value
    return np.array([low_vs_high])


def get_results_per_experiment(dataroot:str='./results/ICCV/', which_summary:str=''):
    results_df = []
    #print(dataroot)
    for study_name in sorted(os.listdir(dataroot)):
        #print(study_name)
        if os.path.isdir(join(dataroot, study_name)):
            tcga_proj = study_name.split('_')[1]
            summary = pd.read_csv(join(dataroot, study_name, 'summary%s.csv' % which_summary), 
                                  index_col=0, usecols=[1,2])
            val_cindex_mean = summary['val_cindex'].mean()
            val_cindex_std = np.std(summary['val_cindex'])
            val_cindex_ci = CI_pm(np.array(summary['val_cindex']))
            val_cindex_max = summary['val_cindex'].max()
            val_argmax = summary['val_cindex'].argmax()
            results_df.append([tcga_proj, val_cindex_mean, val_cindex_std, val_cindex_ci, val_cindex_max, val_argmax])

    results_df = pd.DataFrame(results_df, columns=['Project', 'Val C-Index (Mean)', 'Val C-Index (STD)',
                                                   'Val C-Index (CI)', 'Val C-Index (Max)', 'val_idx'])
    results_df.index = results_df['Project']
    return results_df


def survival_AUC(df1, df2, times=None, method='hc_sksurv'):
    df1['censorship'] = 1-df1['censorship']
    df2['censorship'] = 1-df2['censorship']
    times = np.percentile(df2['survival'], np.linspace(20, 81, 15))

    
    surv1 = np.array(df1, dtype=int)
    risk2 = np.array(df2['risk'])
    surv2 = np.array(df2.drop(['risk'], axis=1), dtype=int)

    surv1 = np.core.records.fromarrays(surv1[:, [1,0]].transpose(), names='obs, survival_months', formats = '?, i8')
    surv2 = np.core.records.fromarrays(surv2[:, [1,0]].transpose(), names='obs, survival_months', formats = '?, i8')
    _, iauc = cumulative_dynamic_auc(surv1, surv2, risk2, times)
    return iauc


def summarize_test(dataroot: str, exp_dir: str, which_summary: str, 
                   save_path:str, risk_percentiles: list=[50], verbose=False, 
                   exp_name=None, create_legend=True, multiplier=1, skip_km=True):
    
    dataroot = join(dataroot, exp_dir)
    results_cindex_df = get_results_per_experiment(dataroot=dataroot, which_summary=which_summary)
    cindex_df, pval_df = {}, {}
    ci_df = {}
    cindex_boot_df = []
    cindex_std_df = []
    figures = []

    for idx, study_dir in tqdm(enumerate(sorted(os.listdir(dataroot)))):
        if os.path.isdir(join(dataroot, study_dir)):
            proj = study_dir.split('_')[1]
            clin_df = pd.read_csv('./dataset_csv_sig/tcga_%s_all_clean.csv.zip' % proj)
            clin_df.index =clin_df['case_id']
            
            results_df = [getResultsFromPKL(dataroot=dataroot, study_dir=study_dir,
                                            split='val', idx=i,which_summary=which_summary) 
                          for i in range(5)]
            
            for idx, val_pred_df in enumerate(results_df):
                split_i = pd.read_csv(os.path.join('./splits/5foldcv/tcga_%s_100/' % proj, 'splits_%d.csv' % idx), index_col=0)
                train_pred_df = clin_df.loc[split_i['train']][['survival_months', 'censorship']]
                
                cin = ci(event_times=val_pred_df['survival'], 
                         predicted_scores=-1*val_pred_df['risk'],
                         event_observed=1-val_pred_df['censorship'])
                iauc = survival_AUC(train_pred_df, val_pred_df.drop('disc_label', axis=1))                
                cindex_std_df.append([cin, iauc, proj.lower(), exp_dir[:3]])
                
            results_df = pd.concat(results_df)
            results_df['case_id'] = results_df.index
            results_df = results_df.groupby('case_id').mean()
            cindex_df[proj] = np.array([ci(event_times=results_df['survival'], 
                                           predicted_scores=-1*results_df['risk'],
                                           event_observed=1-results_df['censorship'])])
            study = "_".join(study_dir.split('_')[:2])
            cindex_boot = 0 # boot_cindex(results_df)
            cindex_boot_df.append(pd.DataFrame([cindex_boot,
                                                [proj.upper()]*1000,
                                                [exp_dir[:3]]*1000]).T)
            ci_df[proj] = [CI_pm(cindex_boot)]
            
            
            if risk_percentiles == [50]:
                pval_df[proj] = getPValue_Binary(results_df=results_df.copy(), risk_percentiles=risk_percentiles)

    cindex_df = pd.DataFrame(cindex_df).T.sort_index()
    cindex_df.columns = ['C-Index (All)']
    pval_df = pd.DataFrame(pval_df).T.sort_index()
    pval_df.columns = ['P-Value']
    ci_df = pd.DataFrame(ci_df).T.sort_index()
    ci_df.columns = ['95% CI']
    results_df = results_cindex_df.join(cindex_df).join(pval_df).join(ci_df)
    
    if verbose:
        display(results_df)
        display(results_df.mean())
        display(pval_df.T < 0.05)
        print("Significant Studies:", (pval_df < 0.05).sum())
        
    cindex_boot_df = pd.concat(cindex_boot_df)
    cindex_boot_df.columns = ['C-Index', 'Study', 'Model']    
    cindex_std_df = pd.DataFrame(cindex_std_df)
    cindex_std_df.columns = ['C-Index', 'I-AUC', 'Project', 'Model']
    
    results_df['Val I-AUC (Mean)'] = cindex_std_df.groupby('Project').mean()['I-AUC']
    results_df['Val I-AUC (STD)'] = cindex_std_df.groupby('Project').std()['I-AUC']
    
    results_df = results_df.drop(['C-Index (All)', 'Val C-Index (Max)', '95% CI'], axis=1)
    return results_df, None, cindex_boot_df, cindex_std_df

### Loads predicted risks for each validation fold across 5-Fold CV evaluated across all models

In [3]:
results_all = {}
dataroot = './results/ICCV/'
for exp_dir in sorted(os.listdir(dataroot)):
    which_summary = '_latest'
    risk_percentiles = [50]
    save_path = os.path.join(dataroot, exp_dir)
    os.makedirs(save_path, exist_ok=True)
    
    if not os.path.isfile('%s/cindex.csv' % save_path):
        results, km, boot, std = summarize_test(dataroot=dataroot, exp_dir=exp_dir, which_summary=which_summary, 
                                             save_path=save_path, risk_percentiles=risk_percentiles, 
                                             verbose=False, exp_name='N/A', create_legend=False, multiplier=1)
        results.to_csv('%s/cindex.csv' % save_path)
    else:
        results = pd.read_csv('%s/cindex.csv' % save_path)
        index = results['Project'].str.contains('|'.join(['blca', 'brca', 'gbmlgg', 'luad', 'ucec']))
        results = results[index].reset_index(drop=True)
            
    results_all[exp_dir] = (results)

# Validation (C-Index Metric)

In [4]:
summary_df = []
overall = []

compare = ['MCATsm_nll_surv_a0.0_5foldcv_gc32_concat']
compare_names = ['MCAT (Ours)']

order = ['SNNsm_nll_surv_a0.0_reg1e-04_5foldcv_gc32_sig',
         'DSsm_nll_surv_a0.0_5foldcv_gc32_sig', 'DSsm_nll_surv_a0.0_5foldcv_gc32_sig_concat', 'DSsm_nll_surv_a0.0_5foldcv_gc32_sig_bilinear',
         'AMILsm_nll_surv_a0.0_5foldcv_gc32_sig', 'AMILsm_nll_surv_a0.0_5foldcv_gc32_sig_concat', 'AMILsm_nll_surv_a0.0_5foldcv_gc32_sig_bilinear',
         'MIFCNsm_nll_surv_a0.0_5foldcv_gc32_sig', 'MIFCNsm_nll_surv_a0.0_5foldcv_gc32_sig_concat', 'MIFCNsm_nll_surv_a0.0_5foldcv_gc32_sig_bilinear'] + compare

names = ['SNN (Genomic Only)', 
         'Deep Sets (WSI Only)', 'Deep Sets (Concat)', 'Deep Sets (Bilinear Pooling)', 
         'Attention MIL (WSI Only)', 'Attention MIL (Concat)', 'Attention MIL (Bilinear Pooling)', 
         'DeepAttnMISL (WSI Only)', 'DeepAttnMISL (Concat)', 'MI-FCN (Bilinear Pooling)'] + compare_names
to_eval = ['BLCA', 'BRCA', 'GBMLGG', 'LUAD', 'UCEC']

means = []
for exp in order:
    results = results_all[exp]
    results.index = results['Project'].str.upper()
    means.append(results['Val C-Index (Mean)'].loc[to_eval].mean())
    exp_cin = results['Val C-Index (Mean)'].map('{:.3f}'.format).astype(str) + ' +/- '
    exp_cin = exp_cin.str.cat(results['Val C-Index (STD)'].map('{:.3f}'.format), sep='')
    #exp_pval = results['P-Value'].map('{:01.2e}'.format)
    summary_df.append(exp_cin)

summary_df = pd.concat(summary_df, axis=1).T
summary_df.columns = results['Project'].str.upper()
summary_df.index = names
means = pd.Series(means, index=names).map('{:.3f}'.format)
summary_df = pd.concat([summary_df[to_eval], means], axis=1)
summary_df.columns = list(pd.Series(to_eval).str.upper()) + ['Overall']
print(summary_df.to_latex())
summary_df

\begin{tabular}{lllllll}
\toprule
{} &             BLCA &             BRCA &           GBMLGG &             LUAD &             UCEC & Overall \\
\midrule
SNN (Genomic Only)               &  0.541 +/- 0.016 &  0.466 +/- 0.058 &  0.598 +/- 0.054 &  0.539 +/- 0.069 &  0.493 +/- 0.096 &   0.527 \\
Deep Sets (WSI Only)             &  0.500 +/- 0.000 &  0.500 +/- 0.000 &  0.498 +/- 0.014 &  0.496 +/- 0.008 &  0.500 +/- 0.000 &   0.499 \\
Deep Sets (Concat)               &  0.604 +/- 0.042 &  0.521 +/- 0.079 &  0.803 +/- 0.046 &  0.616 +/- 0.027 &  0.598 +/- 0.077 &   0.629 \\
Deep Sets (Bilinear Pooling)     &  0.589 +/- 0.050 &  0.522 +/- 0.029 &  0.809 +/- 0.027 &  0.558 +/- 0.038 &  0.593 +/- 0.055 &   0.614 \\
Attention MIL (WSI Only)         &  0.536 +/- 0.038 &  0.564 +/- 0.050 &  0.787 +/- 0.028 &  0.559 +/- 0.060 &  0.625 +/- 0.057 &   0.614 \\
Attention MIL (Concat)           &  0.605 +/- 0.045 &  0.551 +/- 0.077 &  0.816 +/- 0.011 &  0.563 +/- 0.050 &  0.614 +/- 0.052 &   0.630 \\


Unnamed: 0,BLCA,BRCA,GBMLGG,LUAD,UCEC,Overall
SNN (Genomic Only),0.541 +/- 0.016,0.466 +/- 0.058,0.598 +/- 0.054,0.539 +/- 0.069,0.493 +/- 0.096,0.527
Deep Sets (WSI Only),0.500 +/- 0.000,0.500 +/- 0.000,0.498 +/- 0.014,0.496 +/- 0.008,0.500 +/- 0.000,0.499
Deep Sets (Concat),0.604 +/- 0.042,0.521 +/- 0.079,0.803 +/- 0.046,0.616 +/- 0.027,0.598 +/- 0.077,0.629
Deep Sets (Bilinear Pooling),0.589 +/- 0.050,0.522 +/- 0.029,0.809 +/- 0.027,0.558 +/- 0.038,0.593 +/- 0.055,0.614
Attention MIL (WSI Only),0.536 +/- 0.038,0.564 +/- 0.050,0.787 +/- 0.028,0.559 +/- 0.060,0.625 +/- 0.057,0.614
Attention MIL (Concat),0.605 +/- 0.045,0.551 +/- 0.077,0.816 +/- 0.011,0.563 +/- 0.050,0.614 +/- 0.052,0.63
Attention MIL (Bilinear Pooling),0.567 +/- 0.034,0.536 +/- 0.074,0.812 +/- 0.005,0.578 +/- 0.036,0.562 +/- 0.058,0.611
DeepAttnMISL (WSI Only),0.504 +/- 0.042,0.524 +/- 0.043,0.734 +/- 0.029,0.548 +/- 0.050,0.597 +/- 0.059,0.581
DeepAttnMISL (Concat),0.611 +/- 0.049,0.545 +/- 0.071,0.805 +/- 0.014,0.595 +/- 0.061,0.615 +/- 0.020,0.634
MI-FCN (Bilinear Pooling),0.575 +/- 0.032,0.577 +/- 0.063,0.813 +/- 0.022,0.551 +/- 0.038,0.586 +/- 0.036,0.621


# Validation (Integrated AUC Metric)

In [5]:
summary_df = []
overall = []

compare = ['MCATsm_nll_surv_a0.0_5foldcv_gc32_concat']
compare_names = ['MCAT (Ours)']

order = ['SNNsm_nll_surv_a0.0_reg1e-04_5foldcv_gc32_sig',
         'DSsm_nll_surv_a0.0_5foldcv_gc32_sig', 'DSsm_nll_surv_a0.0_5foldcv_gc32_sig_concat', 'DSsm_nll_surv_a0.0_5foldcv_gc32_sig_bilinear',
         'AMILsm_nll_surv_a0.0_5foldcv_gc32_sig', 'AMILsm_nll_surv_a0.0_5foldcv_gc32_sig_concat', 'AMILsm_nll_surv_a0.0_5foldcv_gc32_sig_bilinear',
         'MIFCNsm_nll_surv_a0.0_5foldcv_gc32_sig', 'MIFCNsm_nll_surv_a0.0_5foldcv_gc32_sig_concat', 'MIFCNsm_nll_surv_a0.0_5foldcv_gc32_sig_bilinear'] + compare

names = ['SNN (Genomic Only)', 
         'Deep Sets (WSI Only)', 'Deep Sets (Concat)', 'Deep Sets (Bilinear Pooling)', 
         'Attention MIL (WSI Only)', 'Attention MIL (Concat)', 'Attention MIL (Bilinear Pooling)', 
         'DeepAttnMISL (WSI Only)', 'DeepAttnMISL (Concat)', 'MI-FCN (Bilinear Pooling)'] + compare_names
to_eval = ['BLCA', 'BRCA', 'GBMLGG', 'LUAD', 'UCEC']

means = []
for exp in order:
    results = results_all[exp]
    results.index = results['Project'].str.upper()
    means.append(results['Val I-AUC (Mean)'].loc[to_eval].mean())
    exp_cin = results['Val I-AUC (Mean)'].map('{:.3f}'.format).astype(str) + ' +/- '
    exp_cin = exp_cin.str.cat(results['Val C-Index (STD)'].map('{:.3f}'.format), sep='')
    #exp_pval = results['P-Value'].map('{:01.2e}'.format)
    summary_df.append(exp_cin)

summary_df = pd.concat(summary_df, axis=1).T
summary_df.columns = results['Project'].str.upper()
summary_df.index = names
means = pd.Series(means, index=names).map('{:.3f}'.format)
summary_df = pd.concat([summary_df[to_eval], means], axis=1)
summary_df.columns = list(pd.Series(to_eval).str.upper()) + ['Overall']
print(summary_df.to_latex())
summary_df

\begin{tabular}{lllllll}
\toprule
{} &             BLCA &             BRCA &           GBMLGG &             LUAD &             UCEC & Overall \\
\midrule
SNN (Genomic Only)               &  0.537 +/- 0.016 &  0.476 +/- 0.058 &  0.627 +/- 0.054 &  0.540 +/- 0.069 &  0.476 +/- 0.096 &   0.531 \\
Deep Sets (WSI Only)             &  0.500 +/- 0.000 &  0.500 +/- 0.000 &  0.495 +/- 0.014 &  0.495 +/- 0.008 &  0.500 +/- 0.000 &   0.498 \\
Deep Sets (Concat)               &  0.630 +/- 0.042 &  0.537 +/- 0.079 &  0.832 +/- 0.046 &  0.625 +/- 0.027 &  0.627 +/- 0.077 &   0.650 \\
Deep Sets (Bilinear Pooling)     &  0.616 +/- 0.050 &  0.539 +/- 0.029 &  0.846 +/- 0.027 &  0.558 +/- 0.038 &  0.618 +/- 0.055 &   0.635 \\
Attention MIL (WSI Only)         &  0.533 +/- 0.038 &  0.591 +/- 0.050 &  0.807 +/- 0.028 &  0.574 +/- 0.060 &  0.668 +/- 0.057 &   0.635 \\
Attention MIL (Concat)           &  0.631 +/- 0.045 &  0.565 +/- 0.077 &  0.856 +/- 0.011 &  0.566 +/- 0.050 &  0.618 +/- 0.052 &   0.647 \\


Unnamed: 0,BLCA,BRCA,GBMLGG,LUAD,UCEC,Overall
SNN (Genomic Only),0.537 +/- 0.016,0.476 +/- 0.058,0.627 +/- 0.054,0.540 +/- 0.069,0.476 +/- 0.096,0.531
Deep Sets (WSI Only),0.500 +/- 0.000,0.500 +/- 0.000,0.495 +/- 0.014,0.495 +/- 0.008,0.500 +/- 0.000,0.498
Deep Sets (Concat),0.630 +/- 0.042,0.537 +/- 0.079,0.832 +/- 0.046,0.625 +/- 0.027,0.627 +/- 0.077,0.65
Deep Sets (Bilinear Pooling),0.616 +/- 0.050,0.539 +/- 0.029,0.846 +/- 0.027,0.558 +/- 0.038,0.618 +/- 0.055,0.635
Attention MIL (WSI Only),0.533 +/- 0.038,0.591 +/- 0.050,0.807 +/- 0.028,0.574 +/- 0.060,0.668 +/- 0.057,0.635
Attention MIL (Concat),0.631 +/- 0.045,0.565 +/- 0.077,0.856 +/- 0.011,0.566 +/- 0.050,0.618 +/- 0.052,0.647
Attention MIL (Bilinear Pooling),0.581 +/- 0.034,0.549 +/- 0.074,0.849 +/- 0.005,0.580 +/- 0.036,0.557 +/- 0.058,0.623
DeepAttnMISL (WSI Only),0.488 +/- 0.042,0.534 +/- 0.043,0.753 +/- 0.029,0.569 +/- 0.050,0.628 +/- 0.059,0.594
DeepAttnMISL (Concat),0.634 +/- 0.049,0.551 +/- 0.071,0.839 +/- 0.014,0.602 +/- 0.061,0.622 +/- 0.020,0.649
MI-FCN (Bilinear Pooling),0.580 +/- 0.032,0.592 +/- 0.063,0.848 +/- 0.022,0.549 +/- 0.038,0.594 +/- 0.036,0.633
