In [None]:
import pickle
import datetime
from pathlib import Path
import yaml
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, KFold
from sksurv.ensemble import RandomSurvivalForest
from sksurv.metrics import concordance_index_censored
from sksurv.util import Surv
from sklearn.inspection import permutation_importance
import optuna
import shap
from explainability import SHAP
from evaluation import evaluate_survival_model, PartialLogLikelihood
from training_survival_analysis import train_model
from models import MinimalisticNetwork
import matplotlib.pyplot as plt
import json

import torch
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Dataset


# Dataset class used to feed the DeepSurv model.
# Deliberately simple: the data is assumed to be numeric already, so encode it before wrapping.
class SimpleDataset(Dataset):
    """Map-style dataset yielding ``(features, time, event)`` float32 tensors."""

    def __init__(self, X, y_time, y_event):
        """
        Args:
            X: Input features. Pandas DataFrame or numpy array (numeric data).
            y_time: Survival times. Pandas Series or numpy array (float).
            y_event: Event (failure) indicators. Pandas Series or numpy array (0/1).
        """
        # DataFrames are unwrapped through ``.values``; everything else goes through ``np.asarray``.
        self.X = (
            X.values.astype(np.float32)
            if isinstance(X, pd.DataFrame)
            else np.asarray(X, dtype=np.float32)
        )
        self.y_time, self.y_event = (
            np.asarray(target, dtype=np.float32) for target in (y_time, y_event)
        )

    def __len__(self):
        """Number of samples (rows of ``X``)."""
        return len(self.X)

    def __getitem__(self, index):
        """Return the ``(features, time, event)`` tensors stored at ``index``."""
        sample = (self.X[index], self.y_time[index], self.y_event[index])
        return tuple(torch.tensor(part, dtype=torch.float32) for part in sample)

In [None]:
def prepare_deepsurv_data(df, time_col, event_col, random_state=1234, test_size=0.1, n_splits=5):
    """
    Prepare data for training a DeepSurv model.

    Args:
        df: Pre-processed DataFrame (expected to contain no missing values).
        time_col: Name of the survival-time column (e.g. 'fu_total_yr').
        event_col: Name of the event-indicator column (e.g. 'survival').
        random_state: Random seed, for reproducibility.
        test_size: Train / test split ratio. Defaults to 0.1.
        n_splits: Number of cross-validation folds. Defaults to 5.

    Returns:
        X_train, X_test, y_train, y_test, kfold
        X_* are one-hot encoded float32 DataFrames; y_* are structured arrays
        with fields 'vit_status' (bool) and 'survival_time' (float64);
        kfold is a shuffled KFold seeded with random_state.

    Raises:
        KeyError: if time_col or event_col is not a column of df.
    """
    # Fail early with a message that names the missing column(s).
    missing = [col for col in (time_col, event_col) if col not in df.columns]
    if missing:
        raise KeyError(f"Columns not found in df: {missing}")

    # One-hot encode the independent variables (X).
    X = pd.get_dummies(df.drop(columns=[time_col, event_col]), drop_first=True).astype(np.float32)

    # Convert the dependent variable (y) into a structured array.
    y = np.zeros(df.shape[0], dtype=[('vit_status', '?'), ('survival_time', '<f8')])
    y['vit_status'] = df[event_col].values.astype(bool)
    y['survival_time'] = df[time_col].values.astype(float)

    # Seeds numpy's global RNG as a side effect; the split and KFold below
    # are seeded explicitly through random_state.
    np.random.seed(random_state)

    # Train/test split.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)

    # KFold configuration.
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    return X_train, X_test, y_train, y_test, kfold

In [None]:
def optimize_deepsurv(X_train, y_train, kfold, random_state=1234, n_trials=50):
    """
    Search for the best DeepSurv hyper-parameters with Optuna.

    Each trial is scored by K-fold cross-validation on the training data only:
    a model is trained on the training part of a fold and its concordance
    index is computed on that fold's validation part. The hold-out test set is
    never touched here, so it stays unbiased for the final evaluation.

    Args:
        X_train: Training features (DataFrame; indexed positionally with .iloc).
        y_train: Structured array with fields 'survival_time' and 'vit_status'.
        kfold: K-Fold splitter object (must provide .split(X, y)).
        random_state: Seed for numpy's global RNG and for the TPE sampler.
        n_trials: Number of Optuna trials to run. Defaults to 50.

    Returns:
        best_params: dict of the best trial's hyper-parameters
            (batch_size, inner_dim, lr, weight_decay).
    """
    # Load the configuration file (read relative to the current working directory).
    config = yaml.safe_load(Path("./config.yaml").read_text())
    device = config["device"]
    deepsurv_config = config["deep_surv"]
    np.random.seed(random_state)
    model_name = "deepsurv"

    # Parameters that stay fixed across all trials.
    stable_params = {
        "device": device,
        "input_dim": X_train.shape[1],
        "loss_fn": PartialLogLikelihood,
        "epochs": 300,  # train for 300 epochs
        "model": "minimalistic_network"
        }

    def objective_deep_surv(trial):
        """Objective for Optuna: mean validation concordance index over the folds."""
        # Candidate values for every tunable parameter come from config["deep_surv"].
        flexible_params = {
            "batch_size": trial.suggest_categorical("batch_size", deepsurv_config["batch_size"]),
            "inner_dim": trial.suggest_categorical("inner_dim", deepsurv_config["inner_dim"]),
            "lr": trial.suggest_categorical("lr", deepsurv_config["lr"]),
            "weight_decay": trial.suggest_categorical("weight_decay", deepsurv_config["weight_decay"])
        }
        params = {**stable_params, **flexible_params}
        scores = []

        for train_idx, val_idx in kfold.split(X_train, y_train):
            X_train_fold = X_train.iloc[train_idx]
            X_val_fold = X_train.iloc[val_idx]
            y_val_fold = y_train[val_idx]

            # Wrap the fold in the Dataset class (argument order: features, time, event).
            dataset_train = SimpleDataset(
                X_train_fold,
                y_train["survival_time"][train_idx],
                y_train["vit_status"][train_idx],
            )
            # trial is forwarded to train_model — presumably for pruning/reporting; confirm in train_model.
            model, _, _ = train_model(dataset_train, params, trial=trial)

            model.eval()
            val_tensor = torch.tensor(X_val_fold.values, dtype=torch.float32).to(params["device"])
            y_pred = model(val_tensor).detach().cpu().numpy()
            # Tiny random jitter added to the predictions (kept from the original code).
            y_pred = y_pred + np.random.random(y_pred.shape) * 1e-7

            try:
                fold_score = concordance_index_censored(
                    y_val_fold["vit_status"], y_val_fold["survival_time"], np.squeeze(y_pred)
                )[0]
            except ValueError:
                # The metric cannot be computed for this fold
                # (e.g. only one outcome class is present) — skip the fold.
                continue
            scores.append(fold_score)

        # Fall back to 0.0 when every fold had to be skipped.
        return np.mean(scores) if scores else 0.0

    # Run Optuna.
    study = optuna.create_study(study_name=model_name+str(datetime.datetime.now()),
                                direction="maximize",
                                sampler=optuna.samplers.TPESampler(seed=random_state))
    study.optimize(objective_deep_surv, n_trials=n_trials)

    return study.best_trial.params

In [None]:
def train_and_evaluate_deepsurv(X_train, X_test, y_train, y_test, best_params, kfold, random_state=1234):
    """
    Train and evaluate the DeepSurv model fold by fold.

    For every fold of kfold on the training data, a model is trained with
    best_params, scored with evaluate_survival_model on the fold's validation
    part, and explained with permutation importance (saved as CSV) and a SHAP
    beeswarm plot (saved as PNG). Output files are written to the current
    working directory.

    Args:
        X_train, y_train: Training features (DataFrame) and structured label
            array with fields 'survival_time' and 'vit_status'.
        X_test, y_test: Hold-out test data. Accepted for interface
            compatibility; not used by this function.
        best_params: Best hyper-parameters found by Optuna.
        kfold: K-Fold splitter object.
        random_state: Random seed.

    Returns:
        final_scores: Mean over the folds of the scores returned by
            evaluate_survival_model (concordance index etc.).
    """

    # Load the configuration file (read relative to the current working directory).
    config = yaml.safe_load(Path("./config.yaml").read_text())
    device = config["device"]
    np.random.seed(random_state)

    # Parameters that are fixed regardless of the tuned hyper-parameters.
    stable_params = {
        "device": device,
        "input_dim": X_train.shape[1],
        "loss_fn": PartialLogLikelihood,
        "epochs": 300,  # train for 300 epochs
        "model": "minimalistic_network"
        }

    fold_scores = {}

    for i, (train_fold, val_fold) in enumerate(kfold.split(X_train, y_train)):
        X_train_fold = X_train.iloc[train_fold]
        X_val_fold = X_train.iloc[val_fold]

        # Build the training dataset. SimpleDataset expects (X, y_time, y_event):
        # survival time first, then the event indicator.
        dataset_train = SimpleDataset(
            X_train_fold,
            y_train["survival_time"][train_fold],
            y_train["vit_status"][train_fold]
        )

        # Train the model.
        best_model, _, _ = train_model(dataset_train, {**stable_params, **best_params})
        best_model.eval()

        # Evaluate with evaluate_survival_model.
        scores = evaluate_survival_model(best_model, X_val_fold.values, y_train[train_fold], y_train[val_fold])
        print(f"Final DeepSurv Scores in Fold {i}: {scores}")
        fold_scores[f"fold_{i}"] = scores  # presumably a dict of metric name -> value; stored as-is

        # Compute and store permutation importance.
        # NOTE(review): with no `scoring` argument, sklearn scores through the
        # estimator's own `score` method — confirm the trained model exposes one.
        result = permutation_importance(
            best_model, X_val_fold, y_train[val_fold], n_repeats=15, random_state=random_state)
        result_dict = {k: result[k] for k in ("importances_mean", "importances_std")}
        permutation_importances = pd.DataFrame(result_dict, index=X_val_fold.columns).sort_values(by="importances_mean", ascending=False)

        print("Permutation Importances:")
        print(permutation_importances)

        # Save this fold's permutation importance as a CSV file.
        perm_imp_path = f"DeepSurv_permutation_importances_fold_{i}.csv"
        permutation_importances.to_csv(perm_imp_path, encoding='utf-8')
        print(f"Permutation importances saved to {perm_imp_path}")

        # ---- SHAP values ----
        explainer = shap.Explainer(best_model.predict, X_val_fold.values, feature_names=X_val_fold.columns.tolist())
        shap_values = explainer(X_val_fold.values)

        # ---- Save the SHAP beeswarm plot ----
        plt.figure(figsize=(10, 6))
        shap.plots.beeswarm(shap_values, show=False)

        beeswarm_path = f"DeepSurv_beeswarm_fold_{i}.png"
        plt.tight_layout()  # auto-adjust margins
        plt.savefig(beeswarm_path, dpi=300, bbox_inches="tight")  # avoid clipped labels
        plt.close()  # release the figure's memory
        print(f"Beeswarm plot saved to {beeswarm_path}")

    # Final scores: average of the per-fold results.
    fold_scores_df = pd.DataFrame(fold_scores).T
    final_scores = fold_scores_df.mean(skipna=True)
    print("Final DeepSurv Scores:")
    print(final_scores)

    return final_scores


In [None]:
############# Preprocessing ################

# processed_survival_data_modified: dataset in which Age has been binned into categories
df = pd.read_csv("processed_survival_data_modified.csv")

# Relabel the implant length group in English. na_action="ignore" keeps
# missing values as NaN (instead of silently labelling them "Length < 10"),
# so the dropna below can remove those rows.
df["implant_length_group"] = df["implant_length_group"].map(
    lambda x: "Length ≥ 10" if x == "길이 10 이상" else "Length < 10",
    na_action="ignore",
)

# Encode the dependent variable (surgery outcome) as 0/1; other values become NaN.
df["survival"] = df["survival"].map({"survive": 0, "fail": 1})

# Columns excluded from the analysis.
exclude_columns = ["patient_ID", "me", "failure_reason", "failure_date",
                   "last_fu_date", "surgery_Date", "fu_for_fail_yr", "fu_for_survival_yr"]
all_columns = [col for col in df.columns if col not in exclude_columns]  # columns used in the analysis

# Drop rows with a missing value in any analysed column, then keep only those columns.
df = df.dropna(subset=all_columns)
df = df[all_columns]

time_col = "fu_total_yr"
event_col = "survival"
# Feature columns: everything that is left except the time and event columns.
selected_features = [col for col in df.columns if col not in (time_col, event_col)]

In [None]:
########## Final run #############
# 1) encode + split, 2) hyper-parameter search, 3) per-fold training and evaluation.
X_train, X_test, y_train, y_test, kfold = prepare_deepsurv_data(
    df, time_col="fu_total_yr", event_col="survival"
)
best_params = optimize_deepsurv(X_train=X_train, y_train=y_train, kfold=kfold)
fold_scores_df = train_and_evaluate_deepsurv(
    X_train, X_test, y_train, y_test, best_params=best_params, kfold=kfold
)