In [1]:
import pickle
import datetime
from pathlib import Path
import yaml
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, KFold
from sksurv.ensemble import RandomSurvivalForest
from sksurv.metrics import concordance_index_censored
from sksurv.util import Surv
from sklearn.inspection import permutation_importance
import optuna
import shap
from explainability import SHAP
from evaluation import evaluate_survival_model, PartialLogLikelihood
from training_survival_analysis import train_model
from models import MinimalisticNetwork
import matplotlib.pyplot as plt
import json

import torch
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Dataset

In [2]:
# ------------------------------
# 데이터 로딩 및 전처리 (공통)
# ------------------------------
df = pd.read_csv("processed_survival_data_modified.csv") # processed_survival_data_modified : Age 범주화한 데이터셋
df["survival"] = df["survival"].map({"survive": 0, "fail": 1}) # 종속변수(수술 성공 여부)를 0, 1로 변환

# 코드 통일성을 위해 생존기간, 성공여부 변수 이름 변경
time_col = "fu_total_yr"
event_col = "survival"
df.rename(columns={event_col: "vit_status", time_col: "survival_time"}, inplace=True)

# 분석 제외할 변수 제거
exclude_columns = ["patient_ID", "me", "failure_reason", "failure_date", 
                   "last_fu_date", "surgery_Date", "fu_for_fail_yr", "fu_for_survival_yr"]
selected_features = [col for col in df.columns if col not in ["vit_status", "survival_time"] + exclude_columns] # 분석에 사용할 변수만 포함

# 지정한 컬럼들에 결측치가 있는 행 제거
df = df.dropna(subset = selected_features)

# 평가 함수가 기대하는 구조화 배열 생성 (필드명: 'vit_status', 'survival_time')
y = np.zeros(df.shape[0], dtype=[('vit_status', '?'), ('survival_time', '<f8')])
y['vit_status'] = df["vit_status"].values.astype(bool)
y['survival_time'] = df["survival_time"].values.astype(float)

X = pd.get_dummies(df[selected_features], drop_first=True) # get_dummies : 범주형 변수에 대해 One-hot encoding

In [3]:
# 설정
random_state = 1234
np.random.seed(random_state)
model_name = "rsf"

# ------------------------------
# 데이터 분할 및 K-Fold 설정
# ------------------------------

# Split between Test and Training for Hyperparameter Tuning
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=random_state)
kfold = KFold(n_splits=5, shuffle=True, random_state=random_state)

# ------------------------------
# config.yaml 불러오기 및 RSF 파라미터 갱신
# ------------------------------
config = yaml.safe_load(Path("./config.yaml").read_text())
base_path = config["base_path"]
rsf_config = config["rsf"]
# 'max_features' 최대값은 선택한 특성의 개수로 갱신
rsf_config["max_features"]["max"] = len(selected_features)

In [4]:
# ------------------------------
# Optuna Objective 함수 정의 (RSF)
# ------------------------------
def objective_rsf(trial: optuna.Trial):
    # Hyperparameter 범위 설정
    params = {
        "n_estimators": trial.suggest_int("n_estimators", rsf_config["n_estimators"]["min"], rsf_config["n_estimators"]["max"]),
        #"min_samples_split": trial.suggest_int("min_samples_split", max(2, rsf_config["min_samples_split"]["min"]), rsf_config["min_samples_split"]["max"]),
        "min_samples_leaf": trial.suggest_int("min_samples_leaf", rsf_config["min_samples_leaf"]["min"], rsf_config["min_samples_leaf"]["max"]),
        "max_features": trial.suggest_int("max_features", rsf_config["max_features"]["min"], rsf_config["max_features"]["max"]),
        "max_depth": trial.suggest_int("max_depth", rsf_config["max_depth"]["min"], rsf_config["max_depth"]["max"]),
        #"max_samples": trial.suggest_float("max_samples", rsf_config["max_samples"]["min"], rsf_config["max_samples"]["max"]),
    }
    scores = []
    
    # K-fold Cross Validation으로 계산한 K개의 Concordance index의 평균 계산
    for train_idx, _ in kfold.split(X_train, y_train):
        X_fold = X_train.iloc[train_idx]
        model = RandomSurvivalForest(
            n_estimators=params["n_estimators"],
            min_samples_leaf=params["min_samples_leaf"],
            max_features=params["max_features"],
            max_depth=params["max_depth"],
            n_jobs=-1,
            random_state=random_state
        )
        model.fit(X_fold, y_train[train_idx])
        score = model.score(X_test, y_test)
        scores.append(score)
    return np.mean(scores) 

In [5]:
# Optuna 스터디 생성 및 최적화
study = optuna.create_study(study_name=model_name+str(datetime.datetime.now()),
                                direction="maximize",
                                sampler=optuna.samplers.TPESampler(seed=random_state))
study.optimize(objective_rsf, n_trials=50)
best_params = study.best_trial.params
print("Best RSF parameters:", best_params) # 50번의 Trial로 찾은 최선의 Hyperparameter

[I 2025-03-06 14:55:25,753] A new study created in memory with name: rsf2025-03-06 14:55:25.750676
[I 2025-03-06 14:55:27,596] Trial 0 finished with value: 0.65 and parameters: {'n_estimators': 199, 'min_samples_leaf': 21, 'max_features': 7, 'max_depth': 12}. Best is trial 0 with value: 0.65.
[I 2025-03-06 14:55:35,644] Trial 1 finished with value: 0.6552631578947368 and parameters: {'n_estimators': 782, 'min_samples_leaf': 12, 'max_features': 5, 'max_depth': 13}. Best is trial 1 with value: 0.6552631578947368.
[I 2025-03-06 14:55:47,195] Trial 2 finished with value: 0.6236842105263158 and parameters: {'n_estimators': 959, 'min_samples_leaf': 27, 'max_features': 6, 'max_depth': 9}. Best is trial 1 with value: 0.6552631578947368.
[I 2025-03-06 14:55:52,573] Trial 3 finished with value: 0.6381578947368421 and parameters: {'n_estimators': 687, 'min_samples_leaf': 23, 'max_features': 6, 'max_depth': 9}. Best is trial 1 with value: 0.6552631578947368.
[I 2025-03-06 14:55:56,941] Trial 4 fin

Best RSF parameters: {'n_estimators': 704, 'min_samples_leaf': 20, 'max_features': 11, 'max_depth': 5}


In [9]:
# 각 fold의 결과를 저장할 딕셔너리 생성
fold_scores = {}


# plt.rcParams['font.family'] ='Malgun Gothic'
# plt.rcParams['axes.unicode_minus'] =False

# 최종 모델 학습 및 결과 저장
for i, (train_fold, val_fold) in enumerate(kfold.split(X_train, y_train)):
    X_train_fold = X_train.iloc[train_fold]
    X_val_fold = X_train.iloc[val_fold]
    best_model = RandomSurvivalForest(**best_params, random_state=random_state, n_jobs=-1)
    best_model.fit(X_train_fold, y_train[train_fold])
    
    # 평가 결과 저장
    scores = evaluate_survival_model(best_model, X_val_fold, y_train[train_fold], y_train[val_fold])
    print(f"Final RSF Scores in Fold {i}: {scores}")
    fold_scores[f"fold_{i}"] = scores  

    # ---- Permutation Importance 저장 ----
    result = permutation_importance(best_model, X_val_fold, y_train[val_fold], n_repeats=15, random_state=random_state)
    result_dict = {k: result[k] for k in ("importances_mean", "importances_std")}
    permutation_importances = pd.DataFrame(result_dict, index=X_val_fold.columns).sort_values(by="importances_mean", ascending=False)
    
    print("Permutation Importances:")
    print(permutation_importances)

    perm_imp_path = f"RSF_permutation_importances_fold_{i}.csv"
    permutation_importances.to_csv(perm_imp_path, encoding = 'utf-8')
    print(f"Permutation importances saved to {perm_imp_path}")
    
    # ---- SHAP values 저장 ----
    explainer = shap.Explainer(best_model.predict, X_val_fold)
    shap_values = explainer(X_val_fold)
    
    # ---- SHAP Beeswarm Plot 저장 ----
    plt.figure(figsize=(10, 6))  # 적절한 크기 설정
    shap.plots.beeswarm(shap_values, show=False)
    
    beeswarm_path = f"RSF_beeswarm_fold_{i}.png"
    plt.tight_layout()  # 여백 자동 조정
    plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
    plt.close()  # 메모리 정리
    print(f"Beeswarm plot saved to {beeswarm_path}")
    
# Final Scores
fold_scores = pd.DataFrame(fold_scores).T
final_scores = fold_scores.mean(skipna=True)
print(final_scores)

  true_pos = cumsum_tp / cumsum_tp[-1]


Final RSF Scores in Fold 0: {'c_index': 0.5552699228791774, 'mean_auc': nan, 'ibs': 0.07226376390905163}
Permutation Importances:
                                          importances_mean  importances_std
prosthesis_type_single                            0.137618         0.106779
tooth_loss_reason_perio                           0.097858         0.054851
Sex_M                                             0.034276         0.034267
implant_length_group_길이 10 이상                     0.009254         0.013948
type_of_disability_Group2_Non-Mental              0.005656         0.025226
Systemic_disease_y                                0.005484         0.014538
implant_diameter_group_wide: 5>=                  0.004284         0.007311
implant_diameter_group_regular: 4>= & 5<          0.000943         0.008310
Age_old                                           0.000514         0.003778
compliance_with_SPT_non                           0.000171         0.008195
prosthesis_type_overdenture       

PermutationExplainer explainer: 99it [07:22,  4.61s/it]                        


Beeswarm plot saved to RSF_beeswarm_fold_0.png
Final RSF Scores in Fold 1: {'c_index': 0.728110599078341, 'mean_auc': 0.726331115644746, 'ibs': 0.027670988254160284}
Permutation Importances:
                                          importances_mean  importances_std
prosthesis_type_single                            0.333026         0.094359
tooth_loss_reason_perio                           0.194163         0.128712
Sex_M                                             0.031336         0.073775
type_of_disability_Group2_Non-Mental              0.010138         0.014785
implant_length_group_길이 10 이상                     0.009831         0.005011
Age_old                                           0.004301         0.009134
prosthesis_type_overdenture                       0.000614         0.001567
compliance_with_SPT_non                          -0.001536         0.004659
bone_augmentation_procedure_y                    -0.002458         0.003309
jaw_mx                                           

PermutationExplainer explainer: 99it [07:22,  4.61s/it]                        


Beeswarm plot saved to RSF_beeswarm_fold_1.png
Final RSF Scores in Fold 2: {'c_index': 0.5400355871886121, 'mean_auc': 0.5859466861407313, 'ibs': 0.04844217839848014}
Permutation Importances:
                                          importances_mean  importances_std
implant_site_p                                    0.103737         0.077846
prosthesis_type_single                            0.052906         0.018169
Sex_M                                             0.019810         0.018444
implant_length_group_길이 10 이상                     0.012633         0.004150
type_of_disability_Group2_Non-Mental              0.006406         0.009676
compliance_with_SPT_non                           0.004448         0.005064
prosthesis_type_overdenture                       0.000356         0.002828
bone_augmentation_procedure_y                    -0.000178         0.007757
implant_diameter_group_wide: 5>=                 -0.003203         0.013461
periodontal_diagnosis_group_stage 3,4           

PermutationExplainer explainer: 99it [07:26,  4.65s/it]                        
  plt.tight_layout()  # 여백 자동 조정
  plt.tight_layout()  # 여백 자동 조정
  plt.tight_layout()  # 여백 자동 조정
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지


Beeswarm plot saved to RSF_beeswarm_fold_2.png
Final RSF Scores in Fold 3: {'c_index': 0.8206521739130435, 'mean_auc': 0.8679510931739639, 'ibs': 0.04237957274490644}
Permutation Importances:
                                          importances_mean  importances_std
implant_site_p                                    0.172464         0.051771
prosthesis_type_single                            0.149758         0.063807
tooth_loss_reason_perio                           0.068720         0.040740
Sex_M                                             0.030314         0.015800
type_of_disability_Group2_Non-Mental              0.010628         0.006611
Systemic_disease_y                                0.010386         0.005860
implant_diameter_group_regular: 4>= & 5<          0.009179         0.004049
periodontal_diagnosis_group_stage 3,4             0.006099         0.011658
compliance_with_SPT_non                           0.003744         0.005631
implant_length_group_길이 10 이상                   

PermutationExplainer explainer: 98it [06:50,  4.27s/it]                        


Beeswarm plot saved to RSF_beeswarm_fold_3.png
Final RSF Scores in Fold 4: {'c_index': 0.8812877263581489, 'mean_auc': 0.9060668009689415, 'ibs': 0.04414734628293585}
Permutation Importances:
                                          importances_mean  importances_std
implant_site_p                                    0.388867         0.057675
tooth_loss_reason_perio                           0.086117         0.030990
prosthesis_type_single                            0.051241         0.075941
Sex_M                                             0.038766         0.029264
periodontal_diagnosis_group_stage 3,4             0.038766         0.021763
prosthesis_type_overdenture                       0.024816         0.001749
type_of_disability_Group2_Non-Mental              0.023206         0.010230
jaw_mx                                            0.012743         0.017162
Age_old                                           0.012072         0.004529
implant_diameter_group_regular: 4>= & 5<        

PermutationExplainer explainer: 98it [07:21,  4.60s/it]                        
  plt.tight_layout()  # 여백 자동 조정
  plt.tight_layout()  # 여백 자동 조정
  plt.tight_layout()  # 여백 자동 조정
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지


Beeswarm plot saved to RSF_beeswarm_fold_4.png
c_index     0.705071
mean_auc    0.771574
ibs         0.046981
dtype: float64
