# Survival Analysis Prediction

In [208]:
%matplotlib inline
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

from sksurv.nonparametric import kaplan_meier_estimator
from sksurv.linear_model import CoxPHSurvivalAnalysis
from sksurv.metrics import concordance_index_censored
from sklearn.metrics import roc_curve, auc, balanced_accuracy_score, precision_score, recall_score, f1_score

from sklearn.feature_selection import SelectKBest
from sklearn.pipeline import Pipeline
from sklearn.model_selection import GridSearchCV, KFold

from mm_survival.pipelines import run_survival_analysis_ml

In [209]:
tpm_rna_filename = 'data/gene_expression/MMRF_CoMMpass_IA9_E74GTF_Salmon_entrezID_TPM_hg19.csv'
count_rna_file = 'data/gene_expression/MMRF_CoMMpass_IA9_E74GTF_Salmon_Gene_Counts.txt'
clinical_file = 'data/clinical/sc3_Training_ClinAnnotations.csv'
DE_genes_filename = 'data/gene_expression/differential_expression/DE_genes.txt'
signature_genes_filename = 'data/gene_expression/differential_expression/signature_genes.txt'

df_train, df_train_censored, df_clin_uncensored, df_clin_censored = run_survival_analysis_ml(tpm_rna_filename, count_rna_file, clinical_file,
            DE_genes_filename, signature_genes_filename, 'RF', 200, top_k_genes=40)


We have 57997 genes in the raw counts gene expression matrix
We have 24128 genes in the TPM normalized gene expression matrix
N° of patients in the MMRF cohort, with RNAseq available RNA-seq data: 735
N° of patients in the MMRF cohort, with RNAseq available TPM-normalized RNA-seq data: 735
Number of patients with clinical and sequencing data: 582
Total number of genes in the dataset: 100
(391, 42)


In [210]:
df_clin_censored['D_Status'] = df_clin_censored['D_OS_FLAG'].astype(bool)
df_clin_uncensored['D_Status'] = df_clin_uncensored['D_OS_FLAG'].astype(bool)

Labels for survival analysis

In [211]:
data_y = np.array([(df_clin_uncensored.iloc[i]['D_Status'].astype(bool), df_clin_uncensored.iloc[i]['D_OS']) for i in range(df_clin_uncensored.shape[0])],
                 dtype=[('D_Status', bool), ('D_OS', np.int64)])
data_y_censored = np.array([(df_clin_censored.iloc[i]['D_Status'].astype(bool), df_clin_censored.iloc[i]['D_OS']) for i in range(df_clin_censored.shape[0])],
                 dtype=[('D_Status', bool), ('D_OS', np.int64)])

HR labels

In [212]:
y_hr = df_clin_uncensored['HR_FLAG'].replace({'TRUE': 1, 'FALSE': 0})

In [198]:
data_x = df_train.copy()
data_x_censored  = df_train_censored.copy()

**Train/test splits**

In [199]:
df_train, df_test, y_train , y_test, y_hr_train, y_hr_test = train_test_split(data_x, data_y, y_hr, test_size=0.2)

df_train.shape, df_test.shape, y_train.shape, y_test.shape, y_hr_train.shape, y_hr_test.shape

((312, 42), (79, 42), (312,), (79,), (312,), (79,))

**Augment training set**

In [200]:
df_train = pd.concat([df_train, df_train_censored], axis=0)
y_train = np.hstack([y_train, data_y_censored])
y_hr_train = np.hstack([y_hr_train, ]
                      )
df_train.shape, y_train.shape

((503, 42), (503,))

**Penalized Cox models**

In [201]:
estimator = CoxPHSurvivalAnalysis(alpha=0.01)
estimator.fit(df_train, y_train)



CoxPHSurvivalAnalysis(alpha=0.01)

**Concordance index**

In [202]:
pred_test = estimator.predict(df_test)
pred_train = estimator.predict(df_train)

score_train = concordance_index_censored(y_train["D_Status"], y_train["D_OS"], pred_train)
score_test = concordance_index_censored(y_test["D_Status"], y_test["D_OS"], pred_test)

score_train[0], score_test[0]



(0.7795409786396, 0.6684378320935175)

**Classification metrics**

In [203]:
preds_train = estimator.predict_survival_function(df_train)
preds_test = estimator.predict_survival_function(df_test)



In [204]:
preds_train.shape, preds_test.shape

((503,), (79,))

In [205]:
pred_hr_train, pred_hr_test = [], []

for pred in preds_train:
    for i, t in enumerate(pred.x):
        if t >= 540:
            p = 1 - pred.y[i]
            pred_hr_train.append(p > 0.5)
            break
            
for pred in preds_test:
    for i, t in enumerate(pred.x):
        if t >= 540:
            p = 1 - pred.y[i]
            pred_hr_test.append(p > 0.5)
            break

**Train Metrics**

In [206]:
N = y_hr_train.shape[0]
acc = accuracy_score(pred_hr_train[:N], y_hr_train)
fpr, tpr, _ = roc_curve(pred_hr_train[:N], y_hr_train)
auc_score = auc(fpr, tpr)
recall = recall_score(pred_hr_train[:N], y_hr_train)
precision = precision_score(pred_hr_train[:N], y_hr_train)

acc, auc_score, recall, precision

(0.7051282051282052,
 0.8163509471585244,
 0.9411764705882353,
 0.14953271028037382)

**Test Metrics**

In [207]:
acc = accuracy_score(pred_hr_test, y_hr_test)
fpr, tpr, _ = roc_curve(pred_hr_test, y_hr_test)
auc_score = auc(fpr, tpr)
recall = recall_score(pred_hr_test, y_hr_test)
precision = precision_score(pred_hr_test, y_hr_test)

acc, auc_score, recall, precision

(0.7215189873417721, 0.858974358974359, 1.0, 0.043478260869565216)