In [1]:
import pickle
import pandas as pd
import numpy as np
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import StandardScaler, OneHotEncoder
from sksurv.preprocessing import OneHotEncoder as SurvOneHotEncoder
from sksurv.util import Surv

from sksurv.column import encode_categorical
from sksurv.column import standardize
from sksurv.util import Surv

from sklearn.model_selection import train_test_split
from sksurv.ensemble import RandomSurvivalForest
from sksurv.linear_model import CoxnetSurvivalAnalysis

from sksurv.metrics import (
    concordance_index_censored,
    concordance_index_ipcw,
    cumulative_dynamic_auc,
    integrated_brier_score,
)

def evaluate_model_uno_c(model, test_X, test_y, train_y, times):
    """Score ``model`` on a test split using ``concordance_index_ipcw``.

    Parameters
    ----------
    model
        Object exposing ``predict``; it is called once on ``test_X``.
    test_X
        Features handed to ``model.predict``.
    test_y
        Survival outcomes of the test rows (second argument of the metric).
    train_y
        Survival outcomes of the training rows (first argument of the metric).
    times
        Indexable sequence of time points. Only its last element is used,
        as the ``tau`` truncation time.

    Returns
    -------
    Whatever ``concordance_index_ipcw`` returns, passed through unchanged.
    NOTE(review): in scikit-survival this is presumably a tuple whose first
    element is the concordance index itself -- confirm before treating the
    result as a scalar.
    """
    predictions = model.predict(test_X)
    # The metric is truncated at the final requested time point.
    horizon = times[-1]
    return concordance_index_ipcw(train_y, test_y, predictions, tau=horizon)

In [4]:
# Path of the pickled dataset, relative to the notebook's working directory.
pickle_file = '../data/DATA_DECEASED_ex.pkl'

# NOTE(review): pickle.load can execute arbitrary code while deserializing;
# only run this cell on a file whose origin is trusted.
with open(pickle_file, mode='rb') as pickle_handle:
    dataset = pickle.load(pickle_handle)

In [5]:
# Columns passed to `standardize` when the feature matrix is built.
numeric_features = [
    "AGE", "BMI_CALC", "AGE_DON", "CREAT_TRR",
    "NPKID", "COLD_ISCH_KI", "DIALYSIS_TIME", "KDPI",
]

# Columns passed to `encode_categorical` when the feature matrix is built.
categorical_features = [
    "ON_DIALYSIS", "PRE_TX_TXFUS", "GENDER", "ETHCAT",
    "DIABETES_DON", "DIAB", "HCV_SEROSTATUS",
]


In [6]:
# --- scikit-learn preprocessing objects -------------------------------------
# NOTE: these objects are only constructed here. The fit_transform call below
# is commented out, so X is actually built with the scikit-survival column
# helpers further down.
numeric_transformer = Pipeline([('scaler', StandardScaler())])
categorical_transformer = Pipeline([('encoder', OneHotEncoder(handle_unknown='ignore'))])

# Route each feature group to its transformer.
preprocessor = ColumnTransformer([
    ('num', numeric_transformer, numeric_features),
    ('cat', categorical_transformer, categorical_features),
])

# Final pipeline wrapping the column transformer.
pipeline = Pipeline([('preprocessor', preprocessor)])

# Apply preprocessing to X (disabled):
# X = pipeline.fit_transform(dataset[categorical_features + numeric_features ])

# --- Feature matrix ----------------------------------------------------------
# NOTE(review): `standardize` runs on the full dataset before the train/test
# split performed later in the notebook, so the scaling presumably uses
# statistics that include the test rows -- confirm this is acceptable.
numerical_x = standardize(dataset[numeric_features])
categorical_x = encode_categorical(dataset[categorical_features])
X = pd.concat([numerical_x, categorical_x], axis=1)

# --- Survival outcome --------------------------------------------------------
# NOTE(review): a NaN in PSTATUS would become True after the float -> bool
# cast -- confirm the column has no missing values.
event = dataset["PSTATUS"].astype(float).astype(bool)
survival_time = dataset["PTIME"].astype(np.float64)

# Structured array with a boolean "Status" field and a float "Days" field.
y = Surv.from_arrays(
    event=event,
    time=survival_time,
    name_event="Status",
    name_time="Days",
)

In [7]:
# Hold out 20% of the rows, stratified on the "Status" field of y; the seed is
# fixed so the split is reproducible.
X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    stratify=y["Status"],
    random_state=42,
)

In [None]:
# Random survival forest with 3 estimators, fitted on the training rows from
# position 1000 onward.
# NOTE(review): the first 1000 training rows are skipped; the reason is not
# visible in this notebook -- confirm it is intentional.
rsf = RandomSurvivalForest(
    n_estimators=3,
    n_jobs=-1,
    random_state=42,
    low_memory=True,
)
rsf.fit(X_train[1000:], y_train[1000:])

In [None]:
# Time grid spanning the 10th to 90th percentile of the observed "Days" values.
# evaluate_model_uno_c only uses the last grid point (as tau).
lower, upper = np.percentile(y["Days"], q=[10, 90])
times = np.arange(lower, upper + 1)

# evaluate_model(rsf, X_test, y_test, y_train, times)
# NOTE(review): the first 500 test rows are excluded here; the reason is not
# visible in this notebook -- confirm. The y_train slice matches the rows the
# forest was fitted on.
evaluate_model_uno_c(
    model=rsf,
    test_X=X_test[500:],
    test_y=y_test[500:],
    train_y=y_train[1000:],
    times=times,
)

In [None]:
from sklearn.inspection import permutation_importance

# Permutation importance of the fitted forest on the full test split.
# Per the original note: the data must not be processed by the pipeline.
result = permutation_importance(
    rsf,
    X_test,
    y_test,
    n_repeats=10,
    random_state=0,
    n_jobs=-1,
)

In [None]:
# Show every row when the importance table is printed.
pd.set_option('display.max_rows', None)

# columns = numeric_features + categorical_features

# One row per feature-matrix column, holding the mean permutation importance,
# ordered from most to least important.
importances_df = pd.DataFrame(
    {'Importance': result.importances_mean},
    index=X_train.columns,
).sort_values(by='Importance', ascending=False)

# Print out feature importances
print(importances_df)

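### Permutation importance results (performed on helios)

Note: this table also lists `KDRI_RAO` and `CREAT_DON`, which are not in the `numeric_features` / `categorical_features` lists defined above, so it appears to come from a run with a different feature set than the one in this notebook.
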
| Feature | Importance |
|---:|---:|
| AGE                | 7.603468e-02 |
| DIAB=5.0           | 1.755682e-02 |
| DIAB=3.0           | 1.636021e-02 |
| CREAT_TRR          | 6.719912e-03 |
| ON_DIALYSIS=Y      | 6.290068e-03 |
| AGE_DON            | 4.584326e-03 |
| KDPI               | 4.224856e-03 |
| DIALYSIS_TIME      | 3.907660e-03 |
| KDRI_RAO           | 3.656396e-03 |
| HCV_SEROSTATUS=P   | 3.371945e-03 |
| DIAB=2.0           | 3.240055e-03 |
| ETHCAT=4           | 2.897624e-03 |
| BMI_CALC           | 2.440851e-03 |
| ETHCAT=5           | 2.209945e-03 |
| PRE_TX_TXFUS=Y     | 1.797013e-03 |
| ETHCAT=2           | 1.152406e-03 |
| COLD_ISCH_KI       | 1.093853e-03 |
| NPKID              | 4.144752e-04 |
| HCV_SEROSTATUS=ND  | 2.775765e-04 |
| GENDER=M           | 1.363486e-04 |
| ETHCAT=6           | 6.727253e-05 |
| CREAT_DON          | 1.390463e-05 |
| ETHCAT=7           | 1.339929e-07 |
| DIAB=4.0           |-3.417864e-05 |
| DIAB=998.0         |-3.806376e-05 |
| DIABETES_DON=Y     |-1.558873e-04 |
