In [1]:
import pickle
import datetime
from pathlib import Path
import yaml
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, KFold
from sksurv.ensemble import RandomSurvivalForest
from sksurv.metrics import concordance_index_censored
from sksurv.util import Surv
from sklearn.inspection import permutation_importance
import optuna
import shap
from explainability import SHAP
from evaluation import evaluate_survival_model, PartialLogLikelihood
from training_survival_analysis import train_model
from models import MinimalisticNetwork
import matplotlib.pyplot as plt
import json

import torch
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Dataset


# DeepSurv 사용 위한 Dataset 클래스 정의
# 간단한 Dataset 클래스: 이미 수치형 데이터로 준비되었다고 가정. 인코딩 후 사용
class SimpleDataset(Dataset):
    def __init__(self, X, y_time, y_event):
        """
        Args:
            X: 입력 특성. Pandas DataFrame 또는 numpy array (수치형 데이터)
            y_time: 생존 시간. Pandas Series 또는 numpy array (float)
            y_event: 이벤트(실패 여부). Pandas Series 또는 numpy array (0/1 숫자)
        """
        if isinstance(X, pd.DataFrame):
            self.X = X.values.astype(np.float32)
        else:
            self.X = np.asarray(X, dtype=np.float32)
        self.y_time = np.asarray(y_time, dtype=np.float32)
        self.y_event = np.asarray(y_event, dtype=np.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        return (torch.tensor(self.X[index], dtype=torch.float32),
                torch.tensor(self.y_time[index], dtype=torch.float32),
                torch.tensor(self.y_event[index], dtype=torch.float32))

In [2]:

def prepare_deepsurv_data(df, time_col, event_col, random_state=1234, test_size=0.1, n_splits=5):
    """
    DeepSurv 모델 학습을 위한 데이터 준비 함수.
    
    Args:
        df: 전처리된 데이터프레임 (결측치 없는 상태)
        time_col: 생존 시간을 나타내는 컬럼명 (ex: 'fu_total_yr')
        event_col: 생존 여부를 나타내는 컬럼명 (ex: 'survival')
        random_state: 재현성을 위한 랜덤 시드
        test_size: Train / test set split 비율. 0.1 디폴트
        n_splits: cross validation fold 수. 5 디폴트

    Returns:
        X_train, X_test, y_train, y_test, kfold
    """
    # 독립변수(X) One-Hot Encoding
    X = pd.get_dummies(df.drop(columns=[time_col, event_col]), drop_first=True).astype(np.float32)

    # 종속변수(y)를 구조화 배열로 변환
    y = np.zeros(df.shape[0], dtype=[('vit_status', '?'), ('survival_time', '<f8')])
    y['vit_status'] = df[event_col].values.astype(bool)
    y['survival_time'] = df[time_col].values.astype(float)
    
    np.random.seed(random_state)
    model_name = "deepsurv"
    # 데이터 분할 (Train/Test Split)
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)

    # KFold 설정
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    return X_train, X_test, y_train, y_test, kfold

In [3]:

def optimize_deepsurv(X_train, y_train, kfold, random_state=1234):
    """
    Optuna를 사용해 DeepSurv 모델의 최적 하이퍼파라미터를 찾는 함수.

    Args:
        X_train: 훈련 데이터
        y_train: 생존 분석 라벨
        kfold: K-Fold 객체
        random_state: 랜덤 시드

    Returns:
        best_params: 최적의 하이퍼파라미터
    """
    # 설정 파일 로드
    config = yaml.safe_load(Path("./config.yaml").read_text())
    base_path = config["base_path"]
    device = config["device"]
    deepsurv_config = config["deep_surv"]
    np.random.seed(random_state)
    model_name = "deepsurv"


    stable_params = {
        "device": device,
        "input_dim": X_train.shape[1],
        "loss_fn": PartialLogLikelihood,
        "epochs": 300, # epochs 300개 사용
        "model": "minimalistic_network"
        }

    # DeepSurv 입력값에 맞게 tensor 형태로 변환
    X_train_tensor = torch.tensor(X_train.values, dtype=torch.float32)
    y_train_time = pd.Series(y_train['survival_time'])
    y_train_event = pd.Series(y_train['vit_status'])
    y_train_time_tensor = torch.tensor(np.ascontiguousarray(y_train_time.values), dtype=torch.float32)
    y_train_event_tensor = torch.tensor(np.ascontiguousarray(y_train_event.values), dtype=torch.float32)
    train_dataset = TensorDataset(X_train_tensor, y_train_time_tensor, y_train_event_tensor)

    def objective_deep_surv(trial):
        """Optuna에서 사용할 목적 함수"""
        flexible_params = {
            "batch_size": trial.suggest_categorical("batch_size", deepsurv_config["batch_size"]),
            "inner_dim": trial.suggest_categorical("inner_dim", deepsurv_config["inner_dim"]),
            "lr": trial.suggest_categorical("lr", deepsurv_config["lr"]),
            "weight_decay": trial.suggest_categorical("weight_decay", deepsurv_config["weight_decay"])
        }
        params = {**stable_params, **flexible_params}
        scores = []

        dataset_test = torch.Tensor(X_test.values)

        for train_idx, val_idx in kfold.split(X_train, y_train):
            X_train_fold = X_train.iloc[train_idx]
            dataset_train = SimpleDataset(X_train_fold, y_train["survival_time"][train_idx], y_train["vit_status"][train_idx]) # class 변환
            model, _, _ = train_model(dataset_train, params, trial=trial)

            model.eval()
            y_pred = model(dataset_test.to(params["device"])).detach().cpu().numpy()
            y_pred = y_pred + np.random.random(y_pred.shape) * 1e-7

            try:
                fold_score = concordance_index_censored(y_test["vit_status"], y_test["survival_time"], np.squeeze(y_pred))[0]
            except ValueError: # 예측값에 성공/실패 중 하나가 아예 없는 경우 오류 발생. 해당 과정 스킵
                continue
            scores.append(fold_score)

        return np.mean(scores) if scores else 0.0 # 모든 fold에서 오류가 발생한 경우, 기본값 반환 (예: 0)

    # Optuna 실행
    study = optuna.create_study(study_name=model_name+str(datetime.datetime.now()),
                                direction="maximize",
                                sampler=optuna.samplers.TPESampler(seed=random_state))
    study.optimize(objective_deep_surv, n_trials=50)

    return study.best_trial.params


In [4]:

def train_and_evaluate_deepsurv(X_train, X_test, y_train, y_test, best_params, kfold, random_state=1234):
    """
    DeepSurv 모델을 학습하고 평가하는 함수.
    
    Args:
        X_train, X_test, y_train, y_test: 훈련 및 테스트 데이터
        best_params: Optuna에서 찾은 최적 하이퍼파라미터
        kfold: K-Fold 객체
        random_state: 랜덤 시드
    
    Returns:
        final_scores: 각 fold의 평가 점수 평균값 (Concordance Index 등)
    """

    # 설정 파일 로드
    config = yaml.safe_load(Path("./config.yaml").read_text())
    base_path = config["base_path"]
    device = config["device"]
    deepsurv_config = config["deep_surv"]
    np.random.seed(random_state)
    model_name = "deepsurv"

    stable_params = {
        "device": device,
        "input_dim": X_train.shape[1],
        "loss_fn": PartialLogLikelihood,
        "epochs": 300, # epochs 300개 사용
        "model": "minimalistic_network"
        }
    
    fold_scores = {}

    for i, (train_fold, val_fold) in enumerate(kfold.split(X_train, y_train)):
        X_train_fold = X_train.iloc[train_fold]
        X_val_fold = X_train.iloc[val_fold]
        
        # 학습 데이터셋 생성
        dataset_train = SimpleDataset(
            X_train_fold,
            y_train["vit_status"][train_fold],
            y_train["survival_time"][train_fold]
        )

        # 모델 학습
        best_model, _, _ = train_model(dataset_train, {**stable_params, **best_params})
        best_model.eval()

        # 예측: device에 맞게 Tensor 변환 후 예측 수행
        y_pred = best_model(torch.Tensor(X_val_fold.values).to(stable_params["device"])).detach().cpu().numpy()

        # 평가: evaluate_survival_model 함수 사용
        scores = evaluate_survival_model(best_model, X_val_fold.values, y_train[train_fold], y_train[val_fold])
        print(f"Final DeepSurv Scores in Fold {i}: {scores}")
        fold_scores[f"fold_{i}"] = scores  # scores가 dict 형식일 경우 그대로 저장

        # Permutation Importance 계산 및 저장
        result = permutation_importance(
            best_model, X_val_fold, y_train[val_fold], n_repeats=15, random_state=random_state)
        result_dict = {k: result[k] for k in ("importances_mean", "importances_std")}
        permutation_importances = pd.DataFrame(result_dict, index=X_val_fold.columns).sort_values(by="importances_mean", ascending=False)
        
        print("Permutation Importances:")
        print(permutation_importances)
        
        # 각 fold의 permutation importance CSV 파일로 저장
        perm_imp_path = f"DeepSurv_permutation_importances_fold_{i}.csv"
        permutation_importances.to_csv(perm_imp_path, encoding='utf-8')
        print(f"Permutation importances saved to {perm_imp_path}")

        # ---- SHAP values 저장 ----
        explainer = shap.Explainer(best_model.predict, X_val_fold.values, feature_names=X_val_fold.columns.tolist())  
        shap_values = explainer(X_val_fold.values)
        
        # ---- SHAP Beeswarm Plot 저장 ----
        plt.figure(figsize=(10, 6))  # 적절한 크기 설정
        shap.plots.beeswarm(shap_values, show=False)
        
        beeswarm_path = f"DeepSurv_beeswarm_fold_{i}.png"
        plt.tight_layout()  # 여백 자동 조정
        plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
        plt.close()  # 메모리 정리
        print(f"Beeswarm plot saved to {beeswarm_path}")

    # Final Scores (5-fold 결과 평균)
    fold_scores_df = pd.DataFrame(fold_scores).T
    final_scores = fold_scores_df.mean(skipna=True)
    print("Final DeepSurv Scores:")
    print(final_scores)

    return final_scores


In [5]:
############# 전처리 ################

df = pd.read_csv("processed_survival_data_modified.csv") # processed_survival_data_modified : Age 범주화한 데이터셋
df["implant_length_group"] = df["implant_length_group"].apply(
    lambda x: "Length ≥ 10" if x == "길이 10 이상" else "Length < 10"
)
df["survival"] = df["survival"].map({"survive": 0, "fail": 1}) # 종속변수(수술 성공 여부)를 0, 1로 변환
# 분석 제외할 변수 제거
exclude_columns = ["patient_ID", "me", "failure_reason", "failure_date", 
                   "last_fu_date", "surgery_Date", "fu_for_fail_yr", "fu_for_survival_yr"]
all_columns = [col for col in df.columns if col not in exclude_columns] # 분석에 사용할 변수만 포함
# 지정한 컬럼들에 결측치가 있는 행 제거
df = df.dropna(subset = all_columns)
df = df[all_columns]
selected_features = [col for col in df.columns if col not in ['fu_total_yr', 'survival'] + exclude_columns]
time_col = "fu_total_yr"
event_col = "survival"

In [6]:
########## 최종 실행 코드 #############
X_train, X_test, y_train, y_test, kfold = prepare_deepsurv_data(df, "fu_total_yr", "survival")
best_params = optimize_deepsurv(X_train, y_train, kfold)
fold_scores_df = train_and_evaluate_deepsurv(X_train, X_test, y_train, y_test, best_params, kfold)

[I 2025-03-09 09:50:26,397] A new study created in memory with name: deepsurv2025-03-09 09:50:26.396449
[I 2025-03-09 09:50:57,999] Trial 0 finished with value: 0.588157894736842 and parameters: {'batch_size': 3540, 'inner_dim': 64, 'lr': 0.001, 'weight_decay': 5}. Best is trial 0 with value: 0.588157894736842.
[I 2025-03-09 09:51:21,662] Trial 1 finished with value: 0.475 and parameters: {'batch_size': 512, 'inner_dim': 16, 'lr': 5e-05, 'weight_decay': 5}. Best is trial 0 with value: 0.588157894736842.
[I 2025-03-09 09:51:43,984] Trial 2 finished with value: 0.5526315789473685 and parameters: {'batch_size': 1024, 'inner_dim': 16, 'lr': 0.0005, 'weight_decay': 1}. Best is trial 0 with value: 0.588157894736842.
[I 2025-03-09 09:52:05,981] Trial 3 finished with value: 0.5144736842105264 and parameters: {'batch_size': 512, 'inner_dim': 8, 'lr': 0.0001, 'weight_decay': 5}. Best is trial 0 with value: 0.588157894736842.
[I 2025-03-09 09:52:29,284] Trial 4 finished with value: 0.602631578947

-0.022365076
Final DeepSurv Scores in Fold 0: {'c_index': 0.5989717223650386, 'mean_auc': nan, 'ibs': 0.09689248721495977}
Permutation Importances:
                                          importances_mean  importances_std
implant_diameter_group_regular: 4>= & 5<          0.076007         0.033629
prosthesis_type_overdenture                       0.029649         0.026730
prosthesis_type_single                            0.026564         0.035728
type_of_disability_Group2_Non-Mental              0.021251         0.019385
Age_young                                         0.009769         0.029953
compliance_with_SPT_non                           0.009597         0.014304
Age_old                                           0.004627         0.026834
tooth_loss_reason_perio                           0.003085         0.042658
Systemic_disease_y                               -0.007027         0.048517
bone_augmentation_procedure_y                    -0.007883         0.028152
implant_site_p  

PermutationExplainer explainer: 99it [02:19,  1.46s/it]                        
  plt.tight_layout()  # 여백 자동 조정
  plt.tight_layout()  # 여백 자동 조정
  plt.tight_layout()  # 여백 자동 조정
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지


Beeswarm plot saved to DeepSurv_beeswarm_fold_0.png
-0.15966472
Final DeepSurv Scores in Fold 1: {'c_index': 0.5576036866359447, 'mean_auc': 0.5855385853881027, 'ibs': 0.02955043281314792}
Permutation Importances:
                                          importances_mean  importances_std
implant_length_group_길이 10 이상                     0.146851         0.067746
tooth_loss_reason_perio                           0.065131         0.064381
Age_young                                         0.053763         0.021062
implant_diameter_group_wide: 5>=                  0.051613         0.050878
prosthesis_type_overdenture                       0.035945         0.018139
type_of_disability_Group2_Non-Mental              0.009217         0.037964
Age_old                                           0.000922         0.044812
prosthesis_type_single                            0.000614         0.076833
implant_diameter_group_regular: 4>= & 5<         -0.001229         0.075637
implant_site_p            

PermutationExplainer explainer: 99it [02:13,  1.46s/it]                        
  plt.tight_layout()  # 여백 자동 조정
  plt.tight_layout()  # 여백 자동 조정
  plt.tight_layout()  # 여백 자동 조정
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지


Beeswarm plot saved to DeepSurv_beeswarm_fold_1.png
0.20134309
Final DeepSurv Scores in Fold 2: {'c_index': 0.46708185053380785, 'mean_auc': 0.4069291489174489, 'ibs': 0.061503425642588566}
Permutation Importances:
                                          importances_mean  importances_std
Sex_M                                             0.028588         0.045532
jaw_mx                                            0.021174         0.069520
implant_diameter_group_wide: 5>=                  0.020285         0.063364
implant_site_p                                    0.008007         0.038559
prosthesis_type_overdenture                       0.005338         0.048945
bone_augmentation_procedure_y                     0.005042         0.027245
implant_diameter_group_regular: 4>= & 5<         -0.010380         0.057947
compliance_with_SPT_non                          -0.017260         0.023376
periodontal_diagnosis_group_stage 3,4            -0.022005         0.032690
implant_length_group_길이 1

PermutationExplainer explainer: 99it [02:12,  1.46s/it]                        
  plt.tight_layout()  # 여백 자동 조정
  plt.tight_layout()  # 여백 자동 조정
  plt.tight_layout()  # 여백 자동 조정
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지


Beeswarm plot saved to DeepSurv_beeswarm_fold_2.png
0.2363606
Final DeepSurv Scores in Fold 3: {'c_index': 0.657608695652174, 'mean_auc': 0.6104031419073536, 'ibs': 0.060821912729789124}
Permutation Importances:
                                          importances_mean  importances_std
jaw_mx                                            0.086232         0.059413
Sex_M                                             0.072947         0.050920
type_of_disability_Group2_Non-Mental              0.043237         0.066163
Age_old                                           0.026449         0.015693
implant_site_p                                    0.025242         0.022900
implant_diameter_group_wide: 5>=                  0.016667         0.033162
prosthesis_type_overdenture                       0.013406         0.014045
tooth_loss_reason_perio                           0.007971         0.031415
Systemic_disease_y                                0.006763         0.038880
bone_augmentation_procedure_

PermutationExplainer explainer: 98it [02:11,  1.46s/it]                        
  plt.tight_layout()  # 여백 자동 조정
  plt.tight_layout()  # 여백 자동 조정
  plt.tight_layout()  # 여백 자동 조정
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지


Beeswarm plot saved to DeepSurv_beeswarm_fold_3.png
0.2362793
Final DeepSurv Scores in Fold 4: {'c_index': 0.4909456740442656, 'mean_auc': 0.5287176496952938, 'ibs': 0.07080269036121316}
Permutation Importances:
                                          importances_mean  importances_std
compliance_with_SPT_erratic                       0.090543         0.028073
prosthesis_type_overdenture                       0.054192         0.018183
Sex_M                                             0.035010         0.034255
tooth_loss_reason_perio                           0.032327         0.035139
implant_site_p                                    0.027767         0.028171
Age_young                                         0.013548         0.020127
periodontal_diagnosis_group_stage 3,4             0.008451         0.037414
type_of_disability_Group2_Non-Mental              0.005902         0.015611
implant_length_group_길이 10 이상                     0.005097         0.025269
bone_augmentation_procedure_

PermutationExplainer explainer: 98it [02:11,  1.46s/it]                        
  plt.tight_layout()  # 여백 자동 조정
  plt.tight_layout()  # 여백 자동 조정
  plt.tight_layout()  # 여백 자동 조정
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지


Beeswarm plot saved to DeepSurv_beeswarm_fold_4.png
Final DeepSurv Scores:
c_index     0.554442
mean_auc    0.532897
ibs         0.063914
dtype: float64
c_index     0.554442
mean_auc    0.532897
ibs         0.063914
dtype: float64


In [6]:
X_train, X_test, y_train, y_test, kfold = prepare_deepsurv_data(df, "fu_total_yr", "survival")
best_params = {'batch_size': 1024, 'inner_dim': 64, 'lr': 1e-05, 'weight_decay': 0.01}
fold_scores_df = train_and_evaluate_deepsurv(X_train, X_test, y_train, y_test, best_params, kfold)

  true_pos = cumsum_tp / cumsum_tp[-1]


0.019319084
Final DeepSurv Scores in Fold 0: {'c_index': 0.5912596401028277, 'mean_auc': nan, 'ibs': 0.10353907247237816}
Permutation Importances:
                                          importances_mean  importances_std
implant_diameter_group_regular: 4>= & 5<          0.060583         0.035409
type_of_disability_Group2_Non-Mental              0.028620         0.026062
prosthesis_type_single                            0.025536         0.028183
prosthesis_type_overdenture                       0.018680         0.026953
compliance_with_SPT_non                           0.018680         0.022659
Age_old                                           0.008055         0.020860
Age_young                                         0.001885         0.034058
tooth_loss_reason_perio                          -0.004284         0.043864
Sex_M                                            -0.007027         0.069218
bone_augmentation_procedure_y                    -0.007712         0.034502
implant_length_gr

PermutationExplainer explainer: 99it [02:21,  1.45s/it]                        
  plt.tight_layout()  # 여백 자동 조정
  plt.tight_layout()  # 여백 자동 조정
  plt.tight_layout()  # 여백 자동 조정
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지


Beeswarm plot saved to DeepSurv_beeswarm_fold_0.png


  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator / denominator
  cindex = numerator

-0.11617483
Final DeepSurv Scores in Fold 1: {'c_index': 0.5529953917050692, 'mean_auc': 0.525546104087921, 'ibs': 0.03142389966638173}
Permutation Importances:
                                          importances_mean  importances_std
implant_length_group_길이 10 이상                     0.157911         0.063039
Age_young                                         0.052842         0.021996
implant_diameter_group_wide: 5>=                  0.035023         0.054599
prosthesis_type_overdenture                       0.034716         0.013657
tooth_loss_reason_perio                           0.034716         0.058142
prosthesis_type_single                            0.009524         0.085097
implant_site_p                                    0.001843         0.034477
type_of_disability_Group2_Non-Mental             -0.000922         0.059569
bone_augmentation_procedure_y                    -0.012903         0.053588
Age_old                                          -0.015361         0.048644
imp

PermutationExplainer explainer: 99it [02:12,  1.46s/it]                        
  plt.tight_layout()  # 여백 자동 조정
  plt.tight_layout()  # 여백 자동 조정
  plt.tight_layout()  # 여백 자동 조정
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지


Beeswarm plot saved to DeepSurv_beeswarm_fold_1.png
0.18338749
Final DeepSurv Scores in Fold 2: {'c_index': 0.452846975088968, 'mean_auc': 0.3549270914137278, 'ibs': 0.06316063497974378}
Permutation Importances:
                                          importances_mean  importances_std
implant_diameter_group_wide: 5>=                  0.024911         0.059115
prosthesis_type_overdenture                       0.016133         0.046923
bone_augmentation_procedure_y                     0.010973         0.037631
Sex_M                                            -0.001542         0.041898
jaw_mx                                           -0.007888         0.071447
implant_site_p                                   -0.012989         0.040305
compliance_with_SPT_erratic                      -0.016251         0.039226
compliance_with_SPT_non                          -0.031376         0.023382
implant_length_group_길이 10 이상                    -0.032918         0.039484
prosthesis_type_single      

PermutationExplainer explainer: 99it [02:13,  1.47s/it]                        
  plt.tight_layout()  # 여백 자동 조정
  plt.tight_layout()  # 여백 자동 조정
  plt.tight_layout()  # 여백 자동 조정
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지


Beeswarm plot saved to DeepSurv_beeswarm_fold_2.png
0.22711506
Final DeepSurv Scores in Fold 3: {'c_index': 0.6141304347826086, 'mean_auc': 0.5401079712774866, 'ibs': 0.05955015236923045}
Permutation Importances:
                                          importances_mean  importances_std
jaw_mx                                            0.112681         0.056593
Sex_M                                             0.056401         0.048012
type_of_disability_Group2_Non-Mental              0.033937         0.069020
Age_old                                           0.031401         0.017522
implant_diameter_group_wide: 5>=                  0.019324         0.034857
prosthesis_type_overdenture                       0.012319         0.010505
bone_augmentation_procedure_y                     0.009300         0.014186
implant_site_p                                    0.007729         0.031290
compliance_with_SPT_erratic                       0.006763         0.019684
implant_length_group_길이 10 

PermutationExplainer explainer: 98it [02:11,  1.46s/it]                        


Beeswarm plot saved to DeepSurv_beeswarm_fold_3.png
0.18927641
Final DeepSurv Scores in Fold 4: {'c_index': 0.3983903420523139, 'mean_auc': 0.4220969965894068, 'ibs': 0.06527735044527594}
Permutation Importances:
                                          importances_mean  importances_std
compliance_with_SPT_erratic                       0.067337         0.023941
prosthesis_type_overdenture                       0.019450         0.018741
bone_augmentation_procedure_y                     0.008182         0.013852
Age_young                                         0.000402         0.015415
tooth_loss_reason_perio                          -0.001207         0.034576
implant_length_group_길이 10 이상                    -0.005768         0.037649
type_of_disability_Group2_Non-Mental             -0.010731         0.016341
Sex_M                                            -0.011536         0.034566
periodontal_diagnosis_group_stage 3,4            -0.011804         0.047557
compliance_with_SPT_non    

PermutationExplainer explainer: 98it [02:11,  1.46s/it]                        
  plt.tight_layout()  # 여백 자동 조정
  plt.tight_layout()  # 여백 자동 조정
  plt.tight_layout()  # 여백 자동 조정
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지
  plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # 잘림 방지


Beeswarm plot saved to DeepSurv_beeswarm_fold_4.png
Final DeepSurv Scores:
c_index     0.521925
mean_auc    0.460670
ibs         0.064590
dtype: float64
