In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import math
import warnings
import sklearn
import random
import xgboost as xgb
from sklearn.model_selection import train_test_split
from sklearn.model_selection import ParameterGrid
from sklearn.metrics import roc_curve, auc, f1_score, accuracy_score
from sklearn.metrics import precision_recall_curve, average_precision_score
from sksurv.metrics import cumulative_dynamic_auc, concordance_index_censored
import ast

# Suppress all Python warnings for the rest of the kernel session.
warnings.filterwarnings("ignore")

In [None]:
def train_val_split(deriv_data, shuffle=True, random_state=42, train_frac=0.75):
    """Split a data frame into train and validation parts at patient level.

    All rows sharing the same ``'henkilotunnus'`` value end up in the same
    part, so no patient contributes rows to both train and validation data.

    Parameters
    ----------
    deriv_data : pandas.DataFrame
        Data containing a ``'henkilotunnus'`` (patient identifier) column.
    shuffle : bool, default True
        Shuffle the unique patient identifiers before splitting. When false,
        identifiers are split in order of first appearance.
    random_state : int, default 42
        Seed passed to ``random.seed``. Note that this reseeds Python's global
        ``random`` module on every call.
    train_frac : float, default 0.75
        Fraction of unique patients assigned to the training part
        (rounded down); the remaining patients form the validation part.

    Returns
    -------
    (train_data, val_data) : tuple of pandas.DataFrame
        Row subsets of ``deriv_data`` with freshly reset indices.

    Raises
    ------
    ValueError
        If ``train_frac`` is outside ``[0, 1]``.
    """
    if not 0.0 <= train_frac <= 1.0:
        raise ValueError(f'train_frac must be within [0, 1], got {train_frac}')

    random.seed(random_state)

    # Shuffle a plain list rather than the NumPy array returned by unique():
    # random.shuffle is specified for mutable sequences, and the resulting
    # permutation for a given seed is the same.
    patient_list = list(deriv_data['henkilotunnus'].unique())

    if shuffle:
        random.shuffle(patient_list)

    # Rounding down means any remainder goes to the validation part.
    train_size = int(len(patient_list) * train_frac)
    train_list = patient_list[:train_size]
    val_list = patient_list[train_size:]

    patient_ids = deriv_data['henkilotunnus']
    train_data = deriv_data[patient_ids.isin(train_list)].reset_index(drop=True)
    val_data = deriv_data[patient_ids.isin(val_list)].reset_index(drop=True)

    return train_data, val_data

In [None]:
# Project root; the derivation-data and hyperparameter CSV paths below are built from it.
my_path = '~/mounts/research/husdatalake/disease/scripts/Preleukemia/oona_git'

In [None]:
# Font size used for axis labels, tick labels and titles in the plots.
fs=15

In [None]:
# Disease code: selects the input CSV files and is used as prefix of the saved outputs.
# The plot-title cell further down also recognises 'de_novo_AML', 'primary_MF' and 'any_MN'.
disease = 'MDS'

### Static vs all features cross-validation

In [None]:
# Number of repeated train/validation splits (seeds 1..cv are used for the splits).
cv = 10
# Upper bound on boosting rounds passed to xgb.train (num_boost_round).
nrounds = 1000
# Passed to xgb.train as early_stopping_rounds.
early_stop = 10

In [None]:
# Static vs. all-feature comparison over `cv` repeated patient-level splits.
#
# Every model is trained once per split and both the ROC and the PR curve are
# drawn from the same predictions (one ROC figure and one PR figure per
# feature set), so no model has to be fitted twice to obtain the second plot.
#
# NOTE(review): the validation split is used both for early stopping and for
# scoring, so the reported AUC / AUPRC values are likely optimistic - confirm
# this is acceptable for the intended comparison.

print('\n', disease)

deriv_data = pd.read_csv(my_path + '/data/modelling/' + disease + '_derivation_data.csv')

# Drop e_retic columns
deriv_data = deriv_data.loc[:, ~deriv_data.columns.str.startswith('e_retic')]
all_features = list(deriv_data.columns)
basic_features = ['henkilotunnus', 'time_to_dg', 'disease','sukupuoli_selite', 'age', 'rows_in_last_month']

# Read hyperparameters: take the row with the highest mean AUPRC; its 'params'
# cell is parsed from text (presumably a dict literal - verify against the CSV).
hyperparams = pd.read_csv(my_path + '/optimization/hyperparams/' + disease + '_hyperparameter_results_cv.csv', index_col=0)
max_idx = hyperparams['AUCPR_mean'].idxmax()  #f1_score_mean
params = ast.literal_eval(hyperparams['params'].loc[max_idx])

# Columns whose name contains 'norm' are included in every model.
basic_features.extend([feat for feat in all_features if 'norm' in feat])

feature_pool = [x for x in all_features if x not in basic_features]

# Identifier and outcome columns that must not be used as predictors.
non_feature_cols = ['henkilotunnus', 'disease', 'time_to_dg']
plot_dir = 'results/final_model/plots/static_vs_all_features/cross_validation/'

static_AUCs = []
static_AUCPRs = []
all_AUCs = []
all_AUCPRs = []

# (label, substring): a pool column is selected when the substring occurs in
# its name; the empty substring matches every column, i.e. "all" features.
for feature_type, name_filter in [('static', 'static'), ('all', '')]:

    include = [feat for feat in feature_pool if name_filter in feat]
    deriv_iter = deriv_data[basic_features + include]

    print('\nTraining model with features : ', feature_type)

    print('')
    print(feature_type)
    # The count includes the identifier / outcome columns, as deriv_iter still holds them.
    print(f'Using {len(list(deriv_iter.columns))} features')
    print('')

    if feature_type == 'static':
        color = "#052f82"
        fold_AUCs, fold_AUCPRs = static_AUCs, static_AUCPRs
    else:
        color = "#820e05"
        fold_AUCs, fold_AUCPRs = all_AUCs, all_AUCPRs

    fig_roc, ax_roc = plt.subplots(figsize=(6,6))
    fig_pr, ax_pr = plt.subplots(figsize=(6,6))

    # Chance diagonal, drawn once so that its alpha is not stacked per fold.
    ax_roc.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', alpha=0.2)

    # Do cross validation
    for i in range(cv):

        print('\n\tCV loop no: ', i+1)

        # Patient-level split; the seed depends only on the fold number, so
        # both feature sets are evaluated on the same patients.
        train_data, validation_data = train_val_split(deriv_iter, shuffle=True, random_state=i+1)

        # Separate features and target variables
        x_train = train_data.drop(columns=non_feature_cols)
        y_train = train_data['time_to_dg']

        x_val = validation_data.drop(columns=non_feature_cols)
        y_val = validation_data['time_to_dg']

        # Create DMatrix for XGBoost
        dtrain = xgb.DMatrix(x_train, label=y_train)
        dval = xgb.DMatrix(x_val, label=y_val)

        # Use validation set to watch performance
        watchlist = [(dtrain,'train'), (dval,'eval')]

        # Store validation results
        evals_results = {}

        # Train the model
        print(f'\nTraining the model with parameters: ')
        print(params)

        xgb_model = xgb.train(params, dtrain, num_boost_round=nrounds, early_stopping_rounds=early_stop, evals=watchlist, evals_result=evals_results, verbose_eval=50)

        # Predicted risk scores are ranked against the binary disease label.
        risk_scores_val = xgb_model.predict(dval)
        y_true = validation_data['disease']

        fpr, tpr, thresholds = roc_curve(y_true, risk_scores_val)
        roc_auc = auc(fpr, tpr)

        # Calculate precision and recall
        precision, recall, pr_thresholds = precision_recall_curve(y_true, risk_scores_val)
        average_precision = average_precision_score(y_true, risk_scores_val)

        ax_roc.plot(fpr, tpr, lw=2, label=f'{feature_type} (AUC = {round(roc_auc,2)})', color=color)
        ax_pr.step(recall, precision, where='post', label=f'{feature_type} (AUPRC = {round(average_precision,2)})', color=color)

        fold_AUCs.append(roc_auc)
        fold_AUCPRs.append(average_precision)

    mean_AUC = np.mean(fold_AUCs)
    mean_AUCPR = np.mean(fold_AUCPRs)

    ax_roc.set_xlim([0.0, 1.0])
    ax_roc.set_ylim([0.0, 1.05])
    ax_roc.set_ylabel('Sensitivity', fontsize=fs)
    ax_roc.set_xlabel('1 - Specificity',fontsize=fs)
    ax_roc.set_title(f'{feature_type} features mean AUC= {mean_AUC}', fontsize=fs)

    ax_pr.set_xlabel('Recall', fontsize=fs)
    ax_pr.set_ylabel('Precision', fontsize=fs)
    ax_pr.set_title(f'{feature_type} features mean AUPRC= {mean_AUCPR}', fontsize=fs)

    for fig, ax in ((fig_roc, ax_roc), (fig_pr, ax_pr)):
        ax.tick_params(axis='both', labelsize=fs)
        sns.despine(fig=fig, ax=None, top=True, right=True, left=False, bottom=False, offset=None, trim=False)

    # Save before showing so the written files never depend on backend state.
    fig_roc.savefig(plot_dir + disease + '_roc_' + feature_type + '.png')
    fig_pr.savefig(plot_dir + disease + '_pr_' + feature_type + '.png')

    plt.show()

    # Release the figures; otherwise they accumulate over the loop.
    plt.close(fig_roc)
    plt.close(fig_pr)
    
            


## Create pooled prediction vectors across the 10 CV loops for the DeLong test

In [None]:
# Container for the pooled validation labels and predictions produced by the cross-validation cell below.
cv_vector = pd.DataFrame()

In [None]:
# Number of repeated train/validation splits for the pooled prediction vectors.
cv=10

In [None]:
# Build pooled validation labels and predictions of the static and the
# all-feature model over `cv` repeated patient-level splits. The resulting
# `cv_vector` has one row per validation row and split.

print('\n', disease)

deriv_data = pd.read_csv(my_path + '/data/modelling/' + disease + '_derivation_data.csv')

# Drop e_retic columns
deriv_data = deriv_data.loc[:, ~deriv_data.columns.str.startswith('e_retic')]
all_features = list(deriv_data.columns)
basic_features = ['henkilotunnus', 'time_to_dg', 'disease','sukupuoli_selite', 'age', 'rows_in_last_month']

# Read hyperparameters: take the row with the highest mean AUPRC.
hyperparams = pd.read_csv(my_path + '/optimization/hyperparams/' + disease + '_hyperparameter_results_cv.csv', index_col=0)
max_idx = hyperparams['AUCPR_mean'].idxmax()  #f1_score_mean
params = ast.literal_eval(hyperparams['params'].loc[max_idx])

# Columns whose name contains 'norm' are included in every model.
basic_features.extend([feat for feat in all_features if 'norm' in feat])

feature_pool = [x for x in all_features if x not in basic_features]

# Static model: basic columns plus every pool column with 'static' in its
# name. This is the same definition the static-vs-all cross-validation cell
# uses, so both sections evaluate the same "static" model.
# NOTE(review): if no pool column contains 'static' this equals
# deriv_data[basic_features]; confirm this is the intended static feature set.
static_features = [feat for feat in feature_pool if 'static' in feat]
deriv_static = deriv_data[basic_features + static_features]
print(f'Static model uses {len(static_features)} additional static columns')

# All-feature model: basic columns plus the whole remaining pool.
deriv_all = deriv_data[basic_features + feature_pool]


def _fit_fold(data, seed):
    """Train on one patient-level split of `data` and score its validation part.

    Reads `params`, `nrounds` and `early_stop` from the notebook namespace.
    Returns (model, y_true, y_pred): the trained booster, the validation
    'disease' labels and the model's predictions for the validation rows.
    """
    train_data, validation_data = train_val_split(data, shuffle=True, random_state=seed)
    print(len(validation_data))

    # Identifier and outcome columns are not predictors; 'time_to_dg' is the training label.
    non_feature_cols = ['henkilotunnus', 'disease', 'time_to_dg']
    dtrain = xgb.DMatrix(train_data.drop(columns=non_feature_cols), label=train_data['time_to_dg'])
    dval = xgb.DMatrix(validation_data.drop(columns=non_feature_cols), label=validation_data['time_to_dg'])

    # Use validation set to watch performance
    watchlist = [(dtrain,'train'), (dval,'eval')]

    print(f'\nTraining the model with parameters: ')
    print(params)

    model = xgb.train(params, dtrain, num_boost_round=nrounds, early_stopping_rounds=early_stop, evals=watchlist, evals_result={}, verbose_eval=50)

    return model, validation_data['disease'], model.predict(dval)


cv_ytrue_static = []
cv_ytrue_all = []
cv_ypred_static = []
cv_ypred_all = []

# Do cross validation
for i in range(cv):

    print('\n\tCV loop no: ', i+1)

    print('\nTraining model with static features')
    xgb_model_static, y_true_static, y_pred_static = _fit_fold(deriv_static, i+1)

    print('\nTraining model with all features')
    xgb_model_all, y_true_all, y_pred_all = _fit_fold(deriv_all, i+1)

    # Both prediction columns share one y_true column, which is only valid
    # when the two models were scored on exactly the same validation rows.
    if not np.array_equal(y_true_static.to_numpy(), y_true_all.to_numpy()):
        raise ValueError(f'Static and all-feature validation labels differ in CV loop {i+1}')

    cv_ytrue_static.extend(y_true_static.to_list())
    cv_ytrue_all.extend(y_true_all.to_list())
    cv_ypred_static.extend(list(y_pred_static))
    cv_ypred_all.extend(list(y_pred_all))
    print(len(cv_ytrue_static))


# After the cv loop: build the frame in one step so that re-running this cell
# never depends on what an earlier run left in cv_vector.
cv_vector = pd.DataFrame({
    'y_true': cv_ytrue_static,
    'y_pred_static': cv_ypred_static,
    'y_pred_all': cv_ypred_all,
})

In [None]:
# Display the pooled labels / predictions for a quick visual check.
cv_vector

In [None]:
# Persist the pooled vectors; the path is relative to the notebook's working directory.
cv_vector.to_csv('results/final_model/plots/static_vs_all_features/' + disease + '_static_vs_all_delong_cv_vector.csv', index=False)

## READ CV VECTOR HERE

In [None]:
# Reload the saved vectors so the plotting cells below can run without repeating the training.
cv_vector = pd.read_csv('results/final_model/plots/static_vs_all_features/' + disease + '_static_vs_all_delong_cv_vector.csv')

## Plot static vs. all-feature ROC & PR curves pooled across the 10 CV loops

In [None]:
# ROC curve and AUC of the static model on the pooled validation predictions.
static_fpr, static_tpr, static_thresholds = roc_curve(cv_vector['y_true'], cv_vector['y_pred_static'])
static_roc_auc = auc(static_fpr, static_tpr)


# Same for the all-feature model; both use the shared y_true column.
all_fpr, all_tpr, all_thresholds = roc_curve(cv_vector['y_true'], cv_vector['y_pred_all'])
all_roc_auc = auc(all_fpr, all_tpr)

In [None]:
# Plot title for each known disease code.
DISEASE_DISPLAY_NAMES = {
    'de_novo_AML': 'De novo AML',
    'primary_MF': 'Primary MF',
    'any_MN': 'Any MN',
    'MDS': 'MDS',
}

# Fall back to the raw code for an unlisted disease, so `name` is always
# (re)assigned and can neither be undefined nor left over from a previous run.
name = DISEASE_DISPLAY_NAMES.get(disease, disease)

In [None]:
# Pooled ROC curves of the static and the all-feature model in one figure.
fig, ax = plt.subplots(figsize=(6,6))

# Chance-level diagonal as visual reference.
ax.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--', alpha=0.2)
ax.plot(static_fpr, static_tpr, lw=3, label='Static (AUROC = %0.2f)' % static_roc_auc, color="#052f82")
ax.plot(all_fpr, all_tpr, lw=3, label='All (AUROC = %0.2f)' % all_roc_auc, color="#820e05")

# Top and right spines are removed (seaborn's default despine behaviour).
sns.despine(fig=fig)
ax.legend()

ax.set_ylabel('Sensitivity', fontsize=fs)
ax.set_xlabel('1 - Specificity', fontsize=fs)
plt.xticks(fontsize=fs, rotation=0)
plt.yticks(fontsize=fs, rotation=0)
ax.set_title(name, loc='left', fontsize=fs)

fig.savefig('results/final_model/plots/static_vs_all_features/' + disease + '_static_vs_all_roc.png')

In [None]:
# Precision-recall curve and average precision of the static model on the pooled predictions.
# Note: this overwrites static_thresholds from the ROC cell with the PR thresholds.
static_precision, static_recall, static_thresholds = precision_recall_curve(cv_vector['y_true'], cv_vector['y_pred_static'])
static_average_precision = average_precision_score(cv_vector['y_true'], cv_vector['y_pred_static'])

In [None]:
# Precision-recall curve and average precision of the all-feature model on the pooled predictions.
# Note: this overwrites all_thresholds from the ROC cell with the PR thresholds.
all_precision, all_recall, all_thresholds = precision_recall_curve(cv_vector['y_true'], cv_vector['y_pred_all'])
all_average_precision = average_precision_score(cv_vector['y_true'], cv_vector['y_pred_all'])

In [None]:
# Pooled precision-recall curves of the static and the all-feature model in one figure.
fig, ax = plt.subplots(figsize=(6,6))

ax.step(static_recall, static_precision, where='post', label=f'Static (AUPRC = {static_average_precision:.2f})', lw=3, color="#052f82")
ax.step(all_recall, all_precision, where='post', label=f'All (AUPRC = {all_average_precision:.2f})', lw=3, color="#820e05")

# Top and right spines are removed (seaborn's default despine behaviour).
sns.despine(fig=fig)

ax.set_xlabel('Recall', fontsize=fs)
ax.set_ylabel('Precision', fontsize=fs)
plt.xticks(fontsize=fs, rotation=0)
plt.yticks(fontsize=fs, rotation=0)
ax.legend()
ax.set_title(name, loc='left', fontsize=fs)

fig.savefig('results/final_model/plots/static_vs_all_features/' + disease + '_static_vs_all_pr.png')