In [7]:
import pandas as pd
import numpy as np

import sksurv
print(sksurv.__version__)

from sksurv.util import Surv
from sksurv.ensemble import RandomSurvivalForest
from sksurv.svm import FastSurvivalSVM

from sklearn.model_selection import GridSearchCV
from sklearn.model_selection import KFold
from sklearn.preprocessing import StandardScaler
from sksurv.metrics import concordance_index_censored as sksurv_c_index

from lifelines import CoxPHFitter, AalenAdditiveFitter
from lifelines.utils import concordance_index as lifelines_c_index

from xgboost import XGBClassifier

0.23.1


In [11]:
file_path = './simulated_survival_data'
X_scaled = pd.read_csv(f'{file_path}/X_scaled.csv')

train_df_control = pd.read_csv(f'{file_path}/control_train.csv', index_col=0)
X_train = train_df_control.iloc[:, :20] #gene exp
Y_train = train_df_control.iloc[:, 20:22] #survival time and event

print(X_train.shape, Y_train.shape)

invalid_count = (Y_train['time'] <= 0).sum()
valid_mask = Y_train['time'] > 0

X_filtered = X_train.loc[valid_mask].copy()
y_filtered = Y_train.loc[valid_mask].copy()

X = X_filtered
y = y_filtered

print(X.shape, y.shape)

y_structured = Surv.from_arrays(
    event=y_filtered['event'].values.astype(bool),
    time=y_filtered['time'].values
)

lifelines_df = pd.concat([X, y], axis=1)

(640, 20) (640, 2)
(636, 20) (636, 2)


In [12]:
models_and_params = {
    'RandomForest': {
        'model': RandomSurvivalForest(random_state=42),
        'param_grid': {
            'n_estimators': [100, 200],
            'max_depth': [5, 10],
            'min_samples_leaf': [15, 25]
        }
    },
    'SVM': {
        'model': FastSurvivalSVM(random_state=42),
        'param_grid': {
            'alpha': [1e-3, 1e-2],
            'max_iter': [1000]
        }
    },
    'XGBoost': {
        'model': XGBClassifier(random_state=42, use_label_encoder=False, eval_metric='logloss'),
        'param_grid': {
            'n_estimators': [100, 200],
            'max_depth': [3, 6]
        }
    }
}

In [None]:
num_repeats = 20
k_folds = 5

c_index_results = {name: [] for name in models_and_params.keys()}
c_index_results['CoxPH'] = []
c_index_results['Aalen'] = []

for run in range(num_repeats):
    print(f"\n--- No. {run + 1}/{num_repeats}  ---")
    kf = KFold(n_splits=k_folds, shuffle=True, random_state=run)
    
    for model_name, config in models_and_params.items():
        print(f"fintune and evaluate: {model_name}")
        
        inner_kf = KFold(n_splits=k_folds, shuffle=True, random_state=run)
        gscv = GridSearchCV(config['model'], config['param_grid'], cv=inner_kf, n_jobs=-1)
        
        fold_c_indices = []
        for train_idx, test_idx in kf.split(X_scaled, y_structured):
            X_train, X_test = X_scaled.iloc[train_idx], X_scaled.iloc[test_idx]
            y_train, y_test = y_structured[train_idx], y_structured[test_idx]
            
            if model_name == 'XGBoost':
                # Pass only the event array, not the full structured array
                gscv.fit(X_train, y_train['event'])
            else:
                # For survival models (RSF, SVM), use the structured array
                gscv.fit(X_train, y_train)
            
            best_model = gscv.best_estimator_
            
            if model_name == 'XGBoost':
                # XGBoost as a classifier; C-index needs event and time
                predictions = best_model.predict_proba(X_test)[:, 1]
                c_index = sksurv_c_index(
                    event_indicator=y_test['event'],
                    event_time=y_test['time'],
                    estimate=predictions
                )[0]
            else:
                predictions = best_model.predict(X_test)
                # Corrected call: Pass all three required arguments explicitly
                c_index = sksurv_c_index(
                    event_indicator=y_test['event'],  # The 'event' part of the structured array
                    event_time=y_test['time'],      # The 'time' part of the structured array
                    estimate=predictions
                )[0]

            fold_c_indices.append(c_index)
        
        if np.isnan(np.mean(fold_c_indices)):
            raise ValueError(f"XGBoost C-index resulted in NaN. This indicates an issue in the calculation.")
        
        print(f'finish {run + 1}/{num_repeats} {model_name}')    
        c_index_results[model_name].append(np.mean(fold_c_indices))
    
    # --- 评估 Cox 和 Aalen 模型 (无需调参) ---
    print("evaluate CoxPH...")
    cox_scores = []
    for train_idx, test_idx in kf.split(lifelines_df):
        train_df = lifelines_df.iloc[train_idx]
        test_df = lifelines_df.iloc[test_idx]
        
        cph = CoxPHFitter()
        cph.fit(train_df, duration_col='time', event_col='event')
        
        predictions = cph.predict_partial_hazard(test_df)
        c_index = lifelines_c_index(test_df['time'], predictions, test_df['event'])
        cox_scores.append(c_index)
    print(f'finish {run + 1}/{num_repeats} CoxPH')
    c_index_results['CoxPH'].append(np.mean(cox_scores))

    print("evaluate Aalen...")
    aalen_scores = []
    for train_idx, test_idx in kf.split(lifelines_df):
        train_df = lifelines_df.iloc[train_idx]
        test_df = lifelines_df.iloc[test_idx]
        
        aaf = AalenAdditiveFitter()
        aaf.fit(train_df, duration_col='time', event_col='event')
        
        predictions_df = aaf.predict_survival_function(test_df)
        predictions = -predictions_df.iloc[0]
        
        c_index = lifelines_c_index(test_df['time'], predictions, test_df['event'])
        aalen_scores.append(c_index)
    print(f'finish {run + 1}/{num_repeats} Aalen')
    c_index_results['Aalen'].append(np.mean(aalen_scores))


--- No. 1/20  ---
fintune and evaluate: RandomForest
finish 1/20 RandomForest
fintune and evaluate: SVM
finish 1/20 SVM
fintune and evaluate: XGBoost
finish 1/20 XGBoost
evaluate CoxPH...
finish 1/20 CoxPH
evaluate Aalen...
finish 1/20 Aalen

--- No. 2/20  ---
fintune and evaluate: RandomForest
finish 2/20 RandomForest
fintune and evaluate: SVM
finish 2/20 SVM
fintune and evaluate: XGBoost
finish 2/20 XGBoost
evaluate CoxPH...
finish 2/20 CoxPH
evaluate Aalen...
finish 2/20 Aalen

--- No. 3/20  ---
fintune and evaluate: RandomForest
finish 3/20 RandomForest
fintune and evaluate: SVM
finish 3/20 SVM
fintune and evaluate: XGBoost
finish 3/20 XGBoost
evaluate CoxPH...
finish 3/20 CoxPH
evaluate Aalen...
finish 3/20 Aalen

--- No. 4/20  ---
fintune and evaluate: RandomForest
finish 4/20 RandomForest
fintune and evaluate: SVM
finish 4/20 SVM
fintune and evaluate: XGBoost
finish 4/20 XGBoost
evaluate CoxPH...
finish 4/20 CoxPH
evaluate Aalen...
finish 4/20 Aalen

--- No. 5/20  ---
fintune a

In [24]:
model_performance = {}
for model_name, scores in c_index_results.items():
    mean_score = np.mean(scores)
    std_score = np.std(scores)
    model_performance[model_name] = mean_score
    print(f"{model_name:<15}: average C-index = {mean_score:.4f}, SD = {std_score:.4f}")

# 找出表现最佳的模型
best_model_name = max(model_performance, key=model_performance.get)
print(f"\n the best model is: {best_model_name}, the average C-index is {model_performance[best_model_name]:.4f}")

RandomForest   : average C-index = 0.7663, SD = 0.0071
SVM            : average C-index = 0.8352, SD = 0.0141
XGBoost        : average C-index = 0.7624, SD = 0.0087
CoxPH          : average C-index = 0.1525, SD = 0.0016
Aalen          : average C-index = 0.3210, SD = 0.0076

 the best model is: SVM, the average C-index is 0.8352


In [26]:
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

plt.figure(figsize=(12, 7))
sns.violinplot(x='Model', y='C-index Score', data=c_index_results['Aalen'], palette='Paired')

TypeError: Data source must be a DataFrame or Mapping, not <class 'list'>.

<Figure size 1200x700 with 0 Axes>

In [27]:
c_index_results['Aalen']

[np.float64(0.3210507429354246),
 np.float64(0.3104631803810038),
 np.float64(0.31393250508518705),
 np.float64(0.32255369820541013),
 np.float64(0.3328201446317648),
 np.float64(0.3189149731054562),
 np.float64(0.3153881004166126),
 np.float64(0.31437230044764153),
 np.float64(0.334943043281681),
 np.float64(0.3274156359360302),
 np.float64(0.32901252338228426),
 np.float64(0.3151980756974564),
 np.float64(0.3151959139022626),
 np.float64(0.31578908481599965),
 np.float64(0.30754448365202436),
 np.float64(0.3296278033766177),
 np.float64(0.32433424290395646),
 np.float64(0.3175508225062155),
 np.float64(0.33055223793530536),
 np.float64(0.3237078505218781)]