In [1]:
import pickle
import datetime
from pathlib import Path
import yaml
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, KFold
from sksurv.ensemble import RandomSurvivalForest
from sksurv.metrics import concordance_index_censored
from sksurv.util import Surv
from sklearn.inspection import permutation_importance
import optuna
import shap
from explainability import SHAP
from evaluation import evaluate_survival_model, PartialLogLikelihood
from training_survival_analysis import train_model
from models import MinimalisticNetwork
import matplotlib.pyplot as plt
import json

import torch
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Dataset

In [2]:
def prepare_rsf_data(df, time_col, event_col, random_state=1234, test_size=0.1, n_splits=5):
    """
    RSF 모델 학습을 위한 데이터 준비 함수.
    
    Args:
        df: 전처리된 데이터프레임 (결측치 없는 상태)
        time_col: 생존 시간을 나타내는 컬럼명 (ex: 'fu_total_yr')
        event_col: 생존 여부를 나타내는 컬럼명 (ex: 'survival')
        random_state: 재현성을 위한 랜덤 시드
        test_size: Train / test set split 비율. 0.1 디폴트
        n_splits: cross validation fold 수. 5 디폴트

    Returns:
        X_train, X_test, y_train, y_test, kfold
    """
    # 독립변수(X)에서 time_col, event_col 제거 후 One-Hot Encoding
    X = df.drop(columns=[time_col, event_col])
    X = pd.get_dummies(X, drop_first=True)

    # 종속변수(y)를 구조화 배열로 변환
    y = np.zeros(df.shape[0], dtype=[('vit_status', '?'), ('survival_time', '<f8')])
    y['vit_status'] = df[event_col].values.astype(bool)
    y['survival_time'] = df[time_col].values.astype(float)
    np.random.seed(random_state)

    # 데이터 분할 (Train/Test Split)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)

    # KFold 설정
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    return X_train, X_test, y_train, y_test, kfold

# 함수 과정 : X & y 생성, RSF 학습 위한 One-Hot Encoding, train/test split, fold 정의.

In [3]:
def optimize_rsf(X_train, y_train, kfold, random_state=1234, n_trials=50):
    """
    Optuna를 사용해 Random Survival Forest의 최적 하이퍼파라미터를 찾는 함수.

    Args:
        X_train: 훈련 데이터 (독립변수)
        y_train: 훈련 데이터 (종속변수)
        kfold: K-Fold 객체
        random_state: 랜덤 시드
        n_trials: Optuna 하이퍼파라미터 서치 횟수

    Returns:
        best_params: 최적의 하이퍼파라미터
    """
    np.random.seed(random_state)
    # 설정 파일 로드
    config = yaml.safe_load(Path("./config.yaml").read_text())
    rsf_config = config["rsf"]
    selected_features = X_train.columns
    rsf_config["max_features"]["max"] = len(selected_features)
    model_name = "rsf"

    def objective_rsf(trial: optuna.Trial):
        """Optuna에서 사용할 목적 함수"""
        params = {
            "n_estimators": trial.suggest_int("n_estimators", rsf_config["n_estimators"]["min"], rsf_config["n_estimators"]["max"]),
            "min_samples_leaf": trial.suggest_int("min_samples_leaf", rsf_config["min_samples_leaf"]["min"], rsf_config["min_samples_leaf"]["max"]),
            "max_features": trial.suggest_int("max_features", rsf_config["max_features"]["min"], rsf_config["max_features"]["max"]),
            "max_depth": trial.suggest_int("max_depth", rsf_config["max_depth"]["min"], rsf_config["max_depth"]["max"]),
        }
        scores = []
        for train_idx, _ in kfold.split(X_train, y_train):
            X_fold = X_train.iloc[train_idx]
            model = RandomSurvivalForest(n_estimators=params["n_estimators"],
                                        min_samples_leaf=params["min_samples_leaf"],
                                        max_features=params["max_features"],
                                        max_depth=params["max_depth"],
                                        random_state=random_state, n_jobs=-1)
            model.fit(X_fold, y_train[train_idx])
            score = model.score(X_fold, y_train[train_idx])
            scores.append(score)
        return np.mean(scores)

    # Optuna 실행
    study = optuna.create_study(study_name=model_name+str(datetime.datetime.now()),
                                direction="maximize",
                                sampler=optuna.samplers.TPESampler(seed=random_state))
    study.optimize(objective_rsf, n_trials=n_trials)
    best_params = study.best_trial.params

    return best_params

In [4]:
def train_and_evaluate_rsf(X_train, X_test, y_train, y_test, best_params, kfold, random_state=1234):
    """
    Random Survival Forest를 학습하고 평가하는 함수.
    
    Args:
        X_train, X_test, y_train, y_test: 훈련 및 테스트 데이터
        best_params: Optuna에서 찾은 최적 하이퍼파라미터
        kfold: K-Fold 객체
        random_state: 랜덤 시드
    
    Returns:
        각 fold별 feature importance : .csv 파일로 저장, SHAP values beeswarm plot : .png파일로 저장 
        fi: K-Fold 결과 DataFrame
    """
    np.random.seed(random_state)
    # 각 fold의 결과를 저장할 딕셔너리 생성
    fold_scores = {}
    
    # 최종 모델 학습 및 결과 저장
    for i, (train_fold, val_fold) in enumerate(kfold.split(X_train, y_train)):
        X_train_fold = X_train.iloc[train_fold]
        X_val_fold = X_train.iloc[val_fold]
        best_model = RandomSurvivalForest(**best_params, random_state=random_state, n_jobs=-1)
        best_model.fit(X_train_fold, y_train[train_fold])
        
        # 평가 결과 저장
        scores = evaluate_survival_model(best_model, X_val_fold, y_train[train_fold], y_train[val_fold])
        print(f"Final RSF Scores in Fold {i}: {scores}")
        fold_scores[f"fold_{i}"] = scores  
    
        # ---- Permutation Importance 저장 ----
        result = permutation_importance(best_model, X_val_fold, y_train[val_fold], n_repeats=15, random_state=random_state)
        result_dict = {k: result[k] for k in ("importances_mean", "importances_std")}
        permutation_importances = pd.DataFrame(result_dict, index=X_val_fold.columns).sort_values(by="importances_mean", ascending=False)
    
        perm_imp_path = f"RSF_permutation_importances_fold_{i}.csv"
        permutation_importances.to_csv(perm_imp_path, encoding = 'utf-8')
        print(f"Permutation importances saved to {perm_imp_path}")
        
        # ---- SHAP values 저장 ----
        explainer = shap.Explainer(best_model.predict, X_val_fold)
        shap_values = explainer(X_val_fold)
        
        # ---- SHAP Beeswarm Plot 저장 ----
        plt.figure(figsize=(10, 6))  # 적절한 크기 설정
        shap.plots.beeswarm(shap_values, show=False)
        
        beeswarm_path = f"RSF_beeswarm_fold_{i}.png"
        plt.tight_layout()  # 여백 자동 조정
        plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
        plt.close()  # 메모리 정리
        print(f"Beeswarm plot saved to {beeswarm_path}")
        
    # Final Scores
    fold_scores = pd.DataFrame(fold_scores).T
    final_scores = fold_scores.mean(skipna=True)
    print(final_scores)

    return final_scores

In [5]:
############# 전처리 ################

df = pd.read_csv("processed_survival_data_modified.csv") # processed_survival_data_modified : Age 범주화한 데이터셋
df["implant_length_group"] = df["implant_length_group"].apply(
    lambda x: "Length ≥ 10" if x == "길이 10 이상" else "Length < 10"
)
df["survival"] = df["survival"].map({"survive": 0, "fail": 1}) # 종속변수(수술 성공 여부)를 0, 1로 변환
# 분석 제외할 변수 제거
exclude_columns = ["patient_ID", "me", "failure_reason", "failure_date", 
                   "last_fu_date", "surgery_Date", "fu_for_fail_yr", "fu_for_survival_yr"]
all_columns = [col for col in df.columns if col not in exclude_columns] # 분석에 사용할 변수만 포함
# 지정한 컬럼들에 결측치가 있는 행 제거
df = df.dropna(subset = all_columns)
df = df[all_columns]
selected_features = [col for col in df.columns if col not in ['fu_total_yr', 'survival'] + exclude_columns]
time_col = "fu_total_yr"
event_col = "survival"

In [6]:
########## 최종 실행 코드 #############
X_train, X_test, y_train, y_test, kfold = prepare_rsf_data(df, "fu_total_yr", "survival")
best_params = optimize_rsf(X_train, y_train, kfold)
fold_scores_df = train_and_evaluate_rsf(X_train, X_test, y_train, y_test, best_params, kfold)
print(fold_scores_df)

[I 2025-03-09 11:50:43,580] A new study created in memory with name: rsf2025-03-09 11:50:43.579110
[I 2025-03-09 11:50:45,739] Trial 0 finished with value: 0.8655542072431187 and parameters: {'n_estimators': 199, 'min_samples_leaf': 21, 'max_features': 9, 'max_depth': 12}. Best is trial 0 with value: 0.8655542072431187.
[I 2025-03-09 11:50:56,914] Trial 1 finished with value: 0.8954734093203219 and parameters: {'n_estimators': 782, 'min_samples_leaf': 12, 'max_features': 6, 'max_depth': 13}. Best is trial 1 with value: 0.8954734093203219.
[I 2025-03-09 11:51:10,436] Trial 2 finished with value: 0.841404433219548 and parameters: {'n_estimators': 959, 'min_samples_leaf': 27, 'max_features': 7, 'max_depth': 9}. Best is trial 1 with value: 0.8954734093203219.
[I 2025-03-09 11:51:19,805] Trial 3 finished with value: 0.8552077615387697 and parameters: {'n_estimators': 687, 'min_samples_leaf': 23, 'max_features': 7, 'max_depth': 9}. Best is trial 1 with value: 0.8954734093203219.
[I 2025-03-0

Final RSF Scores in Fold 0: {'c_index': 0.5951156812339332, 'mean_auc': nan, 'ibs': 0.06651475310367631}
Permutation importances saved to RSF_permutation_importances_fold_0.csv


PermutationExplainer explainer: 99it [04:43,  2.92s/it]                        


Beeswarm plot saved to RSF_beeswarm_fold_0.png
Final RSF Scores in Fold 1: {'c_index': 0.7050691244239631, 'mean_auc': 0.7196028308103923, 'ibs': 0.025553780533007896}
Permutation importances saved to RSF_permutation_importances_fold_1.csv


PermutationExplainer explainer: 99it [04:38,  2.93s/it]                        


Beeswarm plot saved to RSF_beeswarm_fold_1.png
Final RSF Scores in Fold 2: {'c_index': 0.6112099644128114, 'mean_auc': 0.6663719885597116, 'ibs': 0.050266213564786104}
Permutation importances saved to RSF_permutation_importances_fold_2.csv


PermutationExplainer explainer: 99it [04:42,  2.97s/it]                        


Beeswarm plot saved to RSF_beeswarm_fold_2.png
Final RSF Scores in Fold 3: {'c_index': 0.7807971014492754, 'mean_auc': 0.8382614629947835, 'ibs': 0.03977933639701428}
Permutation importances saved to RSF_permutation_importances_fold_3.csv


PermutationExplainer explainer: 98it [04:36,  2.94s/it]                        


Beeswarm plot saved to RSF_beeswarm_fold_3.png
Final RSF Scores in Fold 4: {'c_index': 0.9134808853118712, 'mean_auc': 0.9445427988183446, 'ibs': 0.039819320678563246}
Permutation importances saved to RSF_permutation_importances_fold_4.csv


PermutationExplainer explainer: 98it [04:37,  2.95s/it]                        


Beeswarm plot saved to RSF_beeswarm_fold_4.png
c_index     0.721135
mean_auc    0.792195
ibs         0.044387
dtype: float64
c_index     0.721135
mean_auc    0.792195
ibs         0.044387
dtype: float64
