 [ ] To add diseases tests
    - add to nan handling
    - make sure that it does not destroy the dataset
    - make sure that it is in the dataset
    - train the best rsf with the updated dataset and compare results
- Look for other features of interest for deceased donors (cold ischemia time)

In [1]:
import os
os.chdir('/mnt/lustre/helios-home/stadnkyr')
print(os.getcwd())

import logging
from typing import Tuple
import pickle
import datetime

import pandas as pd
import numpy as np
from sksurv.column import encode_categorical
from sksurv.column import standardize
from sksurv.util import Surv

# from columns import COLUMNS
from surv_data_pipeline.columns import COLUMNS
from sklearn.model_selection import train_test_split
from sksurv.ensemble import RandomSurvivalForest

from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sksurv.preprocessing import OneHotEncoder as SurvOneHotEncoder
from sksurv.util import Surv

logger = logging.getLogger(__name__)
logger.setLevel(logging.DEBUG)
console_handler = logging.StreamHandler()
console_handler.setLevel(logging.DEBUG)
formatter = logging.Formatter('%(asctime)s - %(message)s')
console_handler.setFormatter(formatter)
logger.addHandler(console_handler)

/mnt/lustre/helios-home/stadnkyr


In [18]:
class ScikitSurvivalDataLoader:
    df = None
    target = None

    ''' ideas
     - return already smaller dataset with df = df.sample(frac=1).reset_index(drop=True) (draws sample percentage of the dataset)
     - return another variable - validation set (or do that in code itself)
     - add self.validation_set - without filling nan values with median
     - try to train only on newer transplantations - today the chances of survival are much higher than it was in 80s.
     - try to add kidney compatibility indices
     - make defining of the features only here and not to use columns.py
    '''
    
    # target_columns = ["PTIME", "PSTATUS"]

    yes_categorical=['ON_DIALYSIS',"PRE_TX_TXFUS",  "GENDER", "ETHCAT", 
                     "DIABETES_DON",
                     'DIAB',
                    'HCV_SEROSTATUS', 
                    "AGE_GROUP", "DON_TY", "COD_KI", #"CONTIN_CIG_DON" #"DIET_DON",
                    
                ]
    
    yes_numerical = ["PTIME","PSTATUS", 'AGE', "BMI_CALC", 
                     "AGE_DON", 
                     "CREAT_TRR",
                     "NPKID",
                    'COLD_ISCH_KI', 
                    "CREAT_DON",
                    "KDPI",
                    # "GFR",
                    "TX_DATE", "DIALYSIS_DATE","KDRI_RAO"
                    ]

    def __init__(self, patient_survival: bool = True) -> None:
        self.target = ["PTIME", "PSTATUS"]

        self.yes_numerical.insert(0,self.target[1])
        self.yes_numerical.insert(0,self.target[0])

    def load(self, n_samples_to_load=None, fill_na_with_median:bool = True) -> Tuple[pd.DataFrame, any]:
        self._load_pd_df()
        self._apply_exclusion_criteria()

        if n_samples_to_load is not None:
            self.df = self.df.sample(n_samples_to_load, random_state=42)

        # self._divide_train_test_validation()

        self.df = self._handle_nan(fill_na_with_median)
        # self.test = self._handle_nan(fill_na_with_median, self.test)
        # self.validate = self._handle_nan(fill_na_with_median, self.validate)

        return self._get_X_y()
    
    def _load_pd_df(self, columns = COLUMNS):
        logger.info("Loading data into pandas DataFrame...")
        self.df = pd.read_parquet(
            "/mnt/lustre/helios-home/stadnkyr/Kidney_transplants.parquet", 
            engine='auto', columns=list(set(self.yes_categorical+self.yes_numerical)))

        logger.info(f"Done! Loaded df of shape {self.df.shape}")
    
    def get_test_X_y(self):
        return self._get_X_y(self.test)

    def get_validate_X_y(self):
        return self._get_X_y(self.validate)
    
    def _divide_train_test_validation(self):
        self.test = None
        self.validate = None

        logger.info("Dividing the dataset into train, test and validation sets...")

        self.df, test_validation_df = train_test_split(self.df, test_size=0.3, random_state=42)

        self.test, self.validate = train_test_split(test_validation_df, test_size=0.5, random_state=42)

        logger.info("Done!")

    
    def _apply_exclusion_criteria(self):
        '''
        ideas: try to censor unrelated reasons
        '''
        UNRELATED_COD = {998,999, 2801,2803,8065,8064,8063,8062,8050,7237, 7226, 
                         7227, 6853,5808,3899, 3800}
        
        self.df = self.df[self.df['PSTATUS'].notnull()]
        self.df = self.df[self.df['AGE_GROUP'] == "A"]
        self.df = self.df[self.df['DON_TY'] != "F"]
        self.df = self.df[self.df['DIABETES_DON'] != "U"]
        self.df = self.df[self.df['PRE_TX_TXFUS'] != "U"]
        self.df = self.df[self.df['HCV_SEROSTATUS'] != "U"]
        self.df = self.df[self.df['ETHCAT'] != "998"]
        self.df = self.df[self.df['DIAB'] != "998"]
        self.df = self.df[self.df['ETHCAT'] != "9"]

        self.df = self.df[self.df['DON_TY'] == "C"] # deceased donor specification

        # remove death of unrelated  reasons
        mask = ~self.df['COD_KI'].isin(UNRELATED_COD)
        self.df = self.df[mask]
        self.df= self.df.drop('COD_KI', axis=1)

        self.df = self.df.drop('DON_TY', axis=1)
        self.df = self.df.drop('AGE_GROUP', axis=1)

        self.yes_categorical = [x for x in self.yes_categorical if x != 'DON_TY' and x != 'AGE_GROUP']

        self.df['KDPI'] = pd.to_numeric(self.df['KDPI'].str.replace('%', ''))

        # self.df.dropna(subset=['DIALYSIS_DATE', 'TX_DATE'], inplace=True)
        # self.df['DIALYSIS_TIME'] = self.df.apply(lambda row: self._get_difference_in_days(row['DIALYSIS_DATE'], row['TX_DATE']), axis=1)
        self.df['DIALYSIS_TIME'] = self.df.apply(lambda row: 
                                                 self._get_difference_in_days(row['DIALYSIS_DATE'], row['TX_DATE']) 
                                                 if row['DIALYSIS_DATE'] is not None and row['TX_DATE'] 
                                                 is not None else 0, axis=1)

        self.yes_numerical = [x for x in self.yes_numerical if x != 'DIALYSIS_DATE' and x != "TX_DATE"]
        self.yes_numerical.append('DIALYSIS_TIME')

        logger.info(f"{self.df.shape}")

    def _handle_nan(self, fill_na_with_median:bool, dataset = None):
        '''
        evaluation set should not have values that were calculated
        '''
        if dataset is None:
            dataset = self.df
            fill_na_with_median=False
        else:
            fill_na_with_median=False

        if fill_na_with_median:
            logger.info("Handling nan values...")
        else:
            logger.info("Dropping nan values...")
        
        self.yes_categorical = [item for item in self.yes_categorical if item !="COD_KI"]
        dataset.dropna(subset=list(set(self.yes_categorical+self.yes_numerical)), inplace=True)

        dataset.dropna(subset=self.yes_numerical, inplace=True)
        
        # print(dataset.shape)
        logger.info("Done!")
        return dataset

    def _get_X_y(self, dataset = None) -> Tuple[pd.DataFrame, any]:
        if dataset is None:
            dataset = self.df
            
        logger.info("Dividing data into X and y...")

        numeric_features = [x for x in self.yes_numerical if x != "PTIME" and x != "PSTATUS"]
        categorical_features = [x for x in self.yes_categorical if x != "PSTATUS"]

        dataset[numeric_features] = dataset[numeric_features].astype('float64')

        with open('pickle/DATA_DECEASED.pkl', 'wb') as file:
            pickle.dump(dataset, file)

        # Define transformations for numeric and categorical features
        numeric_transformer = Pipeline(steps=[
            ('scaler', StandardScaler())
        ])

        categorical_transformer = Pipeline(steps=[
            ('encoder', OneHotEncoder(handle_unknown='ignore'))
        ])

        # Combine transformations for all features
        preprocessor = ColumnTransformer(
            transformers=[
                ('num', numeric_transformer, numeric_features),
                ('cat', categorical_transformer, categorical_features)
            ]
        )

        # Set up the final pipeline
        pipeline = Pipeline(steps=[
            ('preprocessor', preprocessor)
        ])

        # Apply preprocessing to X
        X = pipeline.fit_transform(dataset[numeric_features + categorical_features])
        # print(dataset[numeric_features + categorical_features].columns)

        categorical_x = encode_categorical(dataset[categorical_features])
        numerical_x = standardize(dataset[numeric_features])
        X = pd.concat([numerical_x, categorical_x], axis=1)
        
        survival_time = dataset[self.target[0]].astype(np.float64)
        event = dataset[self.target[1]].astype(float).astype(np.bool)

        y = Surv.from_arrays(event, survival_time, "Status", "Days")

        self.df = None

        logger.info("Done!")

        with open('pickle/trained_pipeline.pkl', 'wb') as file:
            pickle.dump(pipeline, file)

        return X, y
    
    def _calculate_egfr(self, creatinine, age, gender, race):
        pass
    
    def _get_difference_in_days(self, date1, date2):
        date1_dict = eval(str(date1))
        date2_dict = eval(str(date2))
        
        date1 = date1_dict['$date']
        date2 = date2_dict['$date']

        if isinstance(date1, str) and (date2, str):
            date1_object = datetime.datetime.strptime(date1, "%Y-%m-%dT%H:%M:%SZ")
            date2_object = datetime.datetime.strptime(date2, "%Y-%m-%dT%H:%M:%SZ")

            difference = date2_object - date1_object
            return difference.days
        else:
            return



In [19]:
loader = ScikitSurvivalDataLoader()

X, y = loader.load()

2024-02-28 12:32:18,119 - Loading data into pandas DataFrame...
2024-02-28 12:32:40,815 - Done! Loaded df of shape (993806, 23)
2024-02-28 12:33:03,661 - (314517, 21)
2024-02-28 12:33:03,663 - Dropping nan values...
2024-02-28 12:33:04,262 - Done!
2024-02-28 12:33:04,264 - Dividing data into X and y...
2024-02-28 12:33:05,134 - Done!


In [20]:
X.shape

(117536, 26)

In [None]:
print(y)

[( True, 3501.) ( True, 6479.) (False, 5460.) ... ( True,  739.)
 ( True, 4973.) ( True, 1239.)]


In [21]:
# stratify = y["event"].astype(int)
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, stratify=y["Status"], random_state=42)

In [6]:
from sksurv.metrics import (
    concordance_index_censored,
    concordance_index_ipcw,
    cumulative_dynamic_auc,
    integrated_brier_score,
)

def evaluate_model(model, test_X, test_y, train_y, times):
    pred = model.predict(test_X)
    surv_fn = model.predict_survival_function(test_X, return_array=False)
    surv_prob = np.row_stack([fn(times) for fn in surv_fn])

    uno_concordance = concordance_index_ipcw(train_y, test_y, pred, tau=times[-1])
    ibs = integrated_brier_score(train_y, test_y, surv_prob, times)
    auc, mean_auc = cumulative_dynamic_auc(train_y, test_y, pred, times)

    print(f"Concordance Uno: {round(uno_concordance[0], 3)}")
    print(f"IBS: {round(ibs, 3)}")
    print(f"Mean AUC: {round(mean_auc,3)}")

    return uno_concordance, ibs, mean_auc

# Saving the best model's time dependent-evaluation

In [7]:
import pickle

lower, upper = np.percentile(y["Days"], [10, 90])
times = np.arange(lower, upper + 1)

with open('pickle/models/RSF_DECEASED_0.69.pkl', 'rb') as file:
    best_model = pickle.load(file)

# _,_,_ = evaluate_model(best_model, X_test, y_test, y_train, times)

In [None]:
import matplotlib.pyplot as plt
from sksurv.metrics import brier_score

surv_fn = best_model.predict_survival_function(X_test, return_array=False)

surv_prob = np.row_stack([fn(times) for fn in surv_fn])

bs = brier_score(y_train, y_test, surv_prob, times)

# plt.figure(figsize=(10,6))
# plt.plot(bs[0], bs[1], marker=",")
# # plt.axhline(ibs, linestyle="--")
# # plt.text(5, 0, "{model}", fontsize=12)
# plt.title("Time-dependent Brier Score for the Random Survival Forest")
# plt.xlabel("days")
# plt.ylabel("time-dependent Brier Score")

In [11]:
# with open('pickle/models/RSF_DECEASED_FINAL_BRIER_SCORE.pkl', 'wb') as file:
#     pickle.dump(bs, file)

In [None]:
lower, upper = np.percentile(y_train["Days"], [10, 90])
times = np.arange(lower, upper + 1)
cph_risk_scores = best_model.predict(X_test)
auc, mean_auc = cumulative_dynamic_auc(y_train, y_test, cph_risk_scores, times)


In [13]:
# with open('pickle/models/AUC_RSF_DECEASED_FINAL.pickle', 'wb') as f:
#     pickle.dump((times, auc, mean_auc), f)

# Feature Importance

In [None]:
rsf = RandomSurvivalForest(n_estimators=3, n_jobs=-1, random_state=42)
rsf.fit(X_train[1000:], y_train[1000:])

RandomSurvivalForest(n_estimators=3, n_jobs=-1, random_state=42)

In [22]:
from sklearn.inspection import permutation_importance

# uncoment if you want to calculate permutation importance (data must not be processed by pipeline)
result = permutation_importance(best_model, X_test, y_test, n_repeats=10, random_state=0, n_jobs=1)
# result = permutation_importance(rsf, X_test, y_test, n_repeats=10, random_state=0, n_jobs=1)

In [23]:
pd.set_option('display.max_rows', None)

# columns = numeric_features + categorical_features

importances_df = pd.DataFrame(result.importances_mean, index=X_train.columns)
importances_df.columns = ['Importance']
importances_df.sort_values(by='Importance', ascending=False, inplace=True)

# Print out feature importances
print(importances_df)

                   Importance
AGE                  0.100619
DIAB=5.0             0.018903
DIAB=3.0             0.017945
DIALYSIS_TIME        0.005100
KDRI_RAO             0.004415
DIAB=2.0             0.004288
HCV_SEROSTATUS=P     0.003187
ON_DIALYSIS=Y        0.002347
CREAT_TRR            0.002001
GENDER=M             0.000861
ETHCAT=4             0.000860
BMI_CALC             0.000802
ETHCAT=5             0.000795
AGE_DON              0.000680
PRE_TX_TXFUS=Y       0.000590
NPKID                0.000380
ETHCAT=2             0.000280
COLD_ISCH_KI         0.000219
KDPI                 0.000021
DIAB=4.0             0.000013
CREAT_DON            0.000011
DIAB=998.0           0.000011
HCV_SEROSTATUS=ND    0.000008
ETHCAT=6             0.000008
ETHCAT=7            -0.000001
DIABETES_DON=Y      -0.000023


In [None]:
plt_importances =importances_df.sort_values(by='Importance', ascending=True)
plt_importances.plot.barh(color='blue', legend=False, title='RSF Feature Permutation Importance Living', grid=True, figsize=(8, 9))

                     Importance
AGE                7.603468e-02
DIAB=5.0           1.755682e-02
DIAB=3.0           1.636021e-02
CREAT_TRR          6.719912e-03
ON_DIALYSIS=Y      6.290068e-03
AGE_DON            4.584326e-03
KDPI               4.224856e-03
DIALYSIS_TIME      3.907660e-03
KDRI_RAO           3.656396e-03
HCV_SEROSTATUS=P   3.371945e-03
DIAB=2.0           3.240055e-03
ETHCAT=4           2.897624e-03
BMI_CALC           2.440851e-03
ETHCAT=5           2.209945e-03
PRE_TX_TXFUS=Y     1.797013e-03
ETHCAT=2           1.152406e-03
COLD_ISCH_KI       1.093853e-03
NPKID              4.144752e-04
HCV_SEROSTATUS=ND  2.775765e-04
GENDER=M           1.363486e-04
ETHCAT=6           6.727253e-05
CREAT_DON          1.390463e-05
ETHCAT=7           1.339929e-07
DIAB=4.0          -3.417864e-05
DIAB=998.0        -3.806376e-05
DIABETES_DON=Y    -1.558873e-04

# RSF train

In [None]:
rsf = RandomSurvivalForest(n_estimators=50, n_jobs=-1, max_depth=12, min_samples_split=16, max_features=None, random_state=42)
rsf.fit(X_train, y_train)

In [None]:
evaluate_model(rsf, X_test, y_test, y_train)

In [None]:
from sksurv.metrics import concordance_index_ipcw

lower, upper = np.percentile(y["Days"], [10, 90])
times = np.arange(lower, upper + 1)

pred = rsf.predict(X_test)
uno_concordance = concordance_index_ipcw(y_train[1000:], y_test, pred, tau=times[-1])
uno = float(uno_concordance[0])
uno

0.6908782202813217

In [None]:
# import pickle

# with open('pickle/models/RSF_DECEASED_0.69.pkl', 'wb') as file:
#             pickle.dump(rsf, file)

# RSF fine-tuning

In [None]:
from sklearn.model_selection import train_test_split
from surv_data_pipeline.estimator_evaluation import SurvivalEstimatorEvaluation

from sksurv.metrics import (
    concordance_index_censored,
    concordance_index_ipcw,
    cumulative_dynamic_auc,
    integrated_brier_score,
    brier_score
)
import matplotlib.pyplot as plt


def train_model(model, x, y):
    lower, upper = np.percentile(y["Days"], [10, 90])
    times = np.arange(lower, upper + 1)

    X_train, X_test, y_train, y_test = train_test_split(x, y, test_size=.2, random_state=42)

    model.fit(X_train, y_train)
    # model.fit(X_train, y_train)

    # score = model.score(X_test, y_test)
    # print("Model's score: ", round(score, 3))

    uno_score = SurvivalEstimatorEvaluation.evaluate_model_uno_c(model, X_test, y_test, y_train, times)
    # print("Uno's score: ", round(uno_score[0], 3))
    # auc_score, mean_auc = SurvivalEstimatorEvaluation.evaluate_model_auc(model, X_test.iloc[:5000], y_test[:5000], y, times)
    # print("Mean AUC: ", round(mean_auc, 3))

    # surv = model.predict_survival_function(X_test)#, return_array=False)
    # surv_prob = np.row_stack([fn(times) for fn in surv])

    # ibs = integrated_brier_score(y_train, y_test, surv_prob, times)#best_rsf.event_times_)
    # print('Integrated Brier Score:', round(ibs,3))

    return model, uno_score#, ibs

In [None]:
from sksurv.ensemble import RandomSurvivalForest
from tqdm import tqdm

# n_estimators = [70, 80, 90]
# max_depth = [5,7]
# min_samples_split = [10, 12, 14]
# max_features = [None]

n_estimators = [100, 200]
max_depth = [8, 12, None]
min_samples_split = [12]
max_features = [ None]

best_params = None
# lowest_ibs = 1
highest_cindex = 0
best_rsf_model = None



rsf_gr = RandomSurvivalForest(n_jobs=-1)
pbar = tqdm(total = len(n_estimators)*len(max_depth)*len(min_samples_split)
            *len(max_features), desc='Hyperparameter Tuning')

for n in n_estimators:
    for depth in max_depth:
        for min_split in min_samples_split:
            for max_feat in max_features:
                rsf_gr.set_params(n_estimators=n, 
                                  max_depth=depth, 
                                  min_samples_split=min_split, 
                                  max_features=max_feat)
                rsf_gr, uno = train_model(rsf_gr, 
                                          X.iloc[100:], y[100:])
                                        #   X, y)

                uno = float(uno[0])
                if uno > highest_cindex:
                    highest_cindex = uno
                    best_params = (n, depth, min_split, max_feat)
                    best_rsf_model = rsf_gr

                # Update the progress bar
                pbar.update()

# Close the progress bar
pbar.close() 

print(best_params, highest_cindex)

In [None]:
best_rsf = RandomSurvivalForest(n_jobs=-1,
                              n_estimators=100,
                              max_depth=None,
                              min_samples_split=15,
                              max_features='sqrt',
                               verbose=1
                              )
best_rsf.fit(X_train, y_train)

In [None]:
SurvivalEstimatorEvaluation.evaluate_model_uno_c(best_rsf, X_test.iloc[1000:], 
                                                 y_test[1000:], y_train, times)