In [1]:
import pickle
import datetime
from pathlib import Path
import yaml
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split, KFold
from sksurv.ensemble import RandomSurvivalForest
from sksurv.metrics import concordance_index_censored
from sksurv.util import Surv
from sklearn.inspection import permutation_importance
import optuna
import shap
from explainability import SHAP
from evaluation import evaluate_survival_model, PartialLogLikelihood
from training_survival_analysis import train_model
from models import MinimalisticNetwork

import torch
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset, Dataset


# DeepSurv 사용 위한 Dataset 클래스 정의
# 간단한 Dataset 클래스: 이미 수치형 데이터로 준비되었다고 가정. 인코딩 후 사용
class SimpleDataset(Dataset):
    def __init__(self, X, y_time, y_event):
        """
        Args:
            X: 입력 특성. Pandas DataFrame 또는 numpy array (수치형 데이터)
            y_time: 생존 시간. Pandas Series 또는 numpy array (float)
            y_event: 이벤트(실패 여부). Pandas Series 또는 numpy array (0/1 숫자)
        """
        if isinstance(X, pd.DataFrame):
            self.X = X.values.astype(np.float32)
        else:
            self.X = np.asarray(X, dtype=np.float32)
        self.y_time = np.asarray(y_time, dtype=np.float32)
        self.y_event = np.asarray(y_event, dtype=np.float32)

    def __len__(self):
        return len(self.X)

    def __getitem__(self, index):
        return (torch.tensor(self.X[index], dtype=torch.float32),
                torch.tensor(self.y_time[index], dtype=torch.float32),
                torch.tensor(self.y_event[index], dtype=torch.float32))

In [2]:
# ------------------------------
# 데이터 로딩 및 전처리 (공통)
# ------------------------------
df = pd.read_csv("processed_survival_data_modified.csv") # processed_survival_data_modified : Age 범주화한 데이터셋
df["survival"] = df["survival"].map({"survive": 0, "fail": 1}) # 종속변수(수술 성공 여부)를 0, 1로 변환

# 코드 통일성을 위해 생존기간, 성공여부 변수 이름 변경
time_col = "fu_total_yr"
event_col = "survival"
df.rename(columns={event_col: "vit_status", time_col: "survival_time"}, inplace=True)

# 분석 제외할 변수 제거
exclude_columns = ["patient_ID", "me", "failure_reason", "failure_date", 
                   "last_fu_date", "surgery_Date", "fu_for_fail_yr", "fu_for_survival_yr"]
selected_features = [col for col in df.columns if col not in ["vit_status", "survival_time"] + exclude_columns] # 분석에 사용할 변수만 포함

# 지정한 컬럼들에 결측치가 있는 행 제거
df = df.dropna(subset = selected_features)

# 평가 함수가 기대하는 구조화 배열 생성 (필드명: 'vit_status', 'survival_time')
y = np.zeros(df.shape[0], dtype=[('vit_status', '?'), ('survival_time', '<f8')])
y['vit_status'] = df["vit_status"].values.astype(bool)
y['survival_time'] = df["survival_time"].values.astype(float)

X = pd.get_dummies(df[selected_features], drop_first=True) # get_dummies : 범주형 변수에 대해 One-hot encoding
X = X.astype(np.float32)

In [3]:
# 설정
random_state = 1234
np.random.seed(random_state)
model_name = "deepsurv"

# ------------------------------
# 데이터 분할 및 K-Fold 설정
# ------------------------------

# Split between Test and Training for Hyperparameter Tuning
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.1, random_state=random_state)
kfold = KFold(n_splits=5, shuffle=True, random_state=random_state)

# DeepSurv 입력값에 맞게 tensor 형태로 변환
X_train_tensor = torch.tensor(X_train.values, dtype=torch.float32)
y_train_time = pd.Series(y_train['survival_time'])
y_train_event = pd.Series(y_train['vit_status'])
y_train_time_tensor = torch.tensor(np.ascontiguousarray(y_train_time.values), dtype=torch.float32)
y_train_event_tensor = torch.tensor(np.ascontiguousarray(y_train_event.values), dtype=torch.float32)
train_dataset = TensorDataset(X_train_tensor, y_train_time_tensor, y_train_event_tensor)

# ------------------------------
# config.yaml 불러오기 및 RSF 파라미터 갱신
# ------------------------------
config = yaml.safe_load(Path("./config.yaml").read_text())
base_path = config["base_path"]
deepsurv_config = config["deep_surv"]
device = config["device"]

stable_params = {
    "device": device,
    "input_dim": X_train.shape[1],
    "loss_fn": PartialLogLikelihood,
    "epochs": 300, # epochs 300개 사용
    "model": "minimalistic_network"
    }

In [4]:
# ------------------------------
# Optuna Objective 함수 정의 (DeepSurv)
# ------------------------------

def objective_deep_surv(trial: optuna.Trial, stable_params):
    # Hyperparameter 범위 설정
    flexible_params = {
        "batch_size": trial.suggest_categorical("batch_size", deepsurv_config["batch_size"]),
        "inner_dim": trial.suggest_categorical("inner_dim", deepsurv_config["inner_dim"]),
        "lr": trial.suggest_categorical("lr", deepsurv_config["lr"]),
        "weight_decay": trial.suggest_categorical("weight_decay", deepsurv_config["weight_decay"])
    }
    params = {**stable_params, **flexible_params}
    scores = []

    
    # dataset_test는 전체 테스트 세트를 사용 (전역 X_test)
    dataset_test = torch.Tensor(X_test.values)
    
    for train_idx, val_idx in kfold.split(X_train, y_train):
        X_train_fold = X_train.iloc[train_idx]
        X_val_fold = X_train.iloc[val_idx]

        # Class 변환
        dataset_train = SimpleDataset(
            X_train_fold,
            pd.Series(y_train['survival_time'][train_idx]),
            pd.Series(y_train['vit_status'][train_idx])
        )
        
        # 모델 생성 및 학습 (train_model 함수 호출)
        model, losses, test_eval = train_model(dataset_train, params, trial=trial) # K-fold training data로 model fitting
        model.eval()
        y_pred = model(dataset_test.to(params["device"])).detach().cpu().numpy() # test data로 예측값 계산
        y_pred = y_pred + np.random.random(y_pred.shape) * 1e-7
        
        # 평가: 여기서는 각 fold에 대해 evaluate_survival_model의 c_index (Concordance Index)를 사용
        try:
            fold_score = concordance_index_censored(y_test["vit_status"], y_test["survival_time"], np.squeeze(y_pred))[0]
        except ValueError as e:
            print(f"Fold evaluation skipped due to error: {e}") # 예측값에 성공/실패 중 하나가 아예 없는 경우 오류 발생. 해당 과정 스킵
            continue
        scores.append(fold_score)
    
    if len(scores) == 0:
        # 모든 fold에서 오류가 발생한 경우, 기본값 반환 (예: 0)
        score = 0.0
    else:
        score = np.mean(scores)
    trial_nr = trial.number
    print(f"Trial {trial_nr}: {score}")
    
    return score

In [None]:
# Optuna 스터디 생성 및 최적화
study = optuna.create_study(study_name=model_name+str(datetime.datetime.now()),
                            direction="maximize",
                            sampler=optuna.samplers.TPESampler(seed=random_state))
study.optimize(lambda trial: objective_deep_surv(trial, stable_params), n_trials=50)
best_params = study.best_trial.params
print("Best DeepSurv params:", best_params) # 50번의 Trial로 찾은 최선의 Hyperparameter

[I 2025-03-05 20:44:02,460] A new study created in memory with name: deepsurv2025-03-05 20:44:02.458553
[I 2025-03-05 20:44:25,063] Trial 0 finished with value: 0.588157894736842 and parameters: {'batch_size': 3540, 'inner_dim': 64, 'lr': 0.001, 'weight_decay': 5}. Best is trial 0 with value: 0.588157894736842.


Trial 0: 0.588157894736842


[I 2025-03-05 20:44:46,568] Trial 1 finished with value: 0.475 and parameters: {'batch_size': 512, 'inner_dim': 16, 'lr': 5e-05, 'weight_decay': 5}. Best is trial 0 with value: 0.588157894736842.


Trial 1: 0.475


[I 2025-03-05 20:45:07,898] Trial 2 finished with value: 0.5526315789473685 and parameters: {'batch_size': 1024, 'inner_dim': 16, 'lr': 0.0005, 'weight_decay': 1}. Best is trial 0 with value: 0.588157894736842.


Trial 2: 0.5526315789473685


[I 2025-03-05 20:45:29,191] Trial 3 finished with value: 0.5144736842105264 and parameters: {'batch_size': 512, 'inner_dim': 8, 'lr': 0.0001, 'weight_decay': 5}. Best is trial 0 with value: 0.588157894736842.


Trial 3: 0.5144736842105264


[I 2025-03-05 20:45:50,656] Trial 4 finished with value: 0.6026315789473684 and parameters: {'batch_size': 1024, 'inner_dim': 32, 'lr': 5e-05, 'weight_decay': 1}. Best is trial 4 with value: 0.6026315789473684.


Trial 4: 0.6026315789473684


[I 2025-03-05 20:46:11,982] Trial 5 finished with value: 0.5776315789473684 and parameters: {'batch_size': 1770, 'inner_dim': 64, 'lr': 0.0005, 'weight_decay': 0.001}. Best is trial 4 with value: 0.6026315789473684.


Trial 5: 0.5776315789473684


[I 2025-03-05 20:46:33,234] Trial 6 finished with value: 0.4723684210526316 and parameters: {'batch_size': 512, 'inner_dim': 8, 'lr': 0.001, 'weight_decay': 5}. Best is trial 4 with value: 0.6026315789473684.


Trial 6: 0.4723684210526316


[I 2025-03-05 20:46:54,684] Trial 7 finished with value: 0.5394736842105263 and parameters: {'batch_size': 3540, 'inner_dim': 32, 'lr': 0.0005, 'weight_decay': 1}. Best is trial 4 with value: 0.6026315789473684.


Trial 7: 0.5394736842105263


[I 2025-03-05 20:47:16,015] Trial 8 finished with value: 0.5394736842105263 and parameters: {'batch_size': 3540, 'inner_dim': 16, 'lr': 1e-05, 'weight_decay': 1}. Best is trial 4 with value: 0.6026315789473684.


Trial 8: 0.5394736842105263


[I 2025-03-05 20:47:37,314] Trial 9 finished with value: 0.6013157894736842 and parameters: {'batch_size': 1770, 'inner_dim': 32, 'lr': 5e-05, 'weight_decay': 0.1}. Best is trial 4 with value: 0.6026315789473684.


Trial 9: 0.6013157894736842


[I 2025-03-05 20:47:58,746] Trial 10 finished with value: 0.6013157894736842 and parameters: {'batch_size': 1024, 'inner_dim': 32, 'lr': 5e-05, 'weight_decay': 0.01}. Best is trial 4 with value: 0.6026315789473684.


Trial 10: 0.6013157894736842


[I 2025-03-05 20:48:20,058] Trial 11 finished with value: 0.6013157894736842 and parameters: {'batch_size': 1770, 'inner_dim': 32, 'lr': 5e-05, 'weight_decay': 0.1}. Best is trial 4 with value: 0.6026315789473684.


Trial 11: 0.6013157894736842


[I 2025-03-05 20:48:41,341] Trial 12 finished with value: 0.6013157894736842 and parameters: {'batch_size': 1024, 'inner_dim': 32, 'lr': 5e-05, 'weight_decay': 0.1}. Best is trial 4 with value: 0.6026315789473684.


Trial 12: 0.6013157894736842


[I 2025-03-05 20:49:02,737] Trial 13 finished with value: 0.6013157894736842 and parameters: {'batch_size': 1770, 'inner_dim': 32, 'lr': 5e-05, 'weight_decay': 0.0}. Best is trial 4 with value: 0.6026315789473684.


Trial 13: 0.6013157894736842


[I 2025-03-05 20:49:24,046] Trial 14 finished with value: 0.6092105263157894 and parameters: {'batch_size': 1024, 'inner_dim': 32, 'lr': 1e-05, 'weight_decay': 0.1}. Best is trial 14 with value: 0.6092105263157894.


Trial 14: 0.6092105263157894


[I 2025-03-05 20:49:45,322] Trial 15 finished with value: 0.6092105263157894 and parameters: {'batch_size': 1024, 'inner_dim': 32, 'lr': 1e-05, 'weight_decay': 0.001}. Best is trial 14 with value: 0.6092105263157894.


Trial 15: 0.6092105263157894


[I 2025-03-05 20:50:06,740] Trial 16 finished with value: 0.6092105263157894 and parameters: {'batch_size': 1024, 'inner_dim': 32, 'lr': 1e-05, 'weight_decay': 0.001}. Best is trial 14 with value: 0.6092105263157894.


Trial 16: 0.6092105263157894


[I 2025-03-05 20:50:28,011] Trial 17 finished with value: 0.6078947368421053 and parameters: {'batch_size': 1024, 'inner_dim': 8, 'lr': 1e-05, 'weight_decay': 0.001}. Best is trial 14 with value: 0.6092105263157894.


Trial 17: 0.6078947368421053


In [None]:
# 최종 모델 학습
for i, (train_fold , val_fold) in enumerate(kfold.split(X_train, y_train)):
    X_train_fold = X_train.iloc[train_fold]
    X_val_fold = X_train.iloc[val_fold]
    dataset_train = SimpleDataset(X_train_fold, y_train["vit_status"][train_fold], y_train["survival_time"][train_fold])
    best_model, losses, test_eval = train_model(dataset_train, {**stable_params, **best_params})
    best_model.eval()
    y_pred = best_model(torch.Tensor(X_val_fold.values).to(stable_params["device"])).detach().cpu().numpy()
    scores = evaluate_survival_model(best_model, X_val_fold.values, y_train[train_fold],
                                                y_train[val_fold])
    print(f"Final DeepSurv Scores in Fold {i}: {scores}")

    result = permutation_importance(
    best_model, X_val_fold, y_train[val_fold], n_repeats=15, random_state=random_state,
            )
    result_dict = {k: result[k] for k in ("importances_mean", "importances_std")}
    permutation_importances = pd.DataFrame(
            result_dict,
            index=X_val_fold.columns).sort_values(by="importances_mean", ascending=False)
    print(f"Permutation Importances:")
    print(permutation_importances)
    explainer = shap.Explainer(best_model.predict, X_val_fold)
    shap_values = explainer(X_val_fold)
    shap.plots.beeswarm(shap_values)
