In [None]:
import pickle
import datetime
from pathlib import Path
import yaml
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
import math
from typing import List, Tuple
from sklearn.model_selection import train_test_split, KFold
from sksurv.util import Surv
from sksurv.metrics import concordance_index_censored
from models import TabNetSurvivalRegressor
import optuna
from data_preprocessing import encode_selected_variables, Encoder
from sklearn.preprocessing import LabelEncoder
from pytorch_tabnet.abstract_model import TabModel
from pytorch_tabnet.utils import filter_weights

# TabNetSurvivalRegressor: TabNet 기반 생존 분석 모델 클래스 (models.py에 정의)
from models import TabNetSurvivalRegressor

# 평가 함수와 손실 함수 불러오기
from evaluation import PartialLogLikelihood, evaluate_survival_model

from sklearn.inspection import permutation_importance
import shap
from explainability import SHAP
import matplotlib.pyplot as plt
import json

In [None]:
def prepare_tabnet_data(df, time_col, event_col, cat_cols, random_state=1234, test_size=0.1, n_splits=5,
                        return_encoders=False):
    """
    Prepare a survival dataset for TabNet training.

    Steps: separate the survival target from the predictors, label-encode the
    categorical predictors (integer codes 0..k-1 per column), record the
    categorical column positions / cardinalities TabNet needs, split off a
    test set and build a shuffled KFold splitter.

    Args:
        df: preprocessed DataFrame (expected to contain no missing values).
        time_col: name of the survival-time column (e.g. 'fu_total_yr').
        event_col: name of the event column (e.g. 'survival'); cast to bool.
        cat_cols: names of the categorical predictors
            (e.g. ['Sex', 'prosthesis_type', ..., 'periodontal_diagnosis_group']).
        random_state: seed for numpy's global RNG, the train/test split and
            the KFold shuffling.
        test_size: fraction of rows held out as the test set (default 0.1).
        n_splits: number of cross-validation folds (default 5).
        return_encoders: if True, additionally return the fitted LabelEncoder
            of every categorical column (useful to inverse-transform codes).

    Returns:
        (X_train, X_test, y_train, y_test, kfold, cat_idxs, cat_dims), plus a
        dict ``{column: LabelEncoder}`` as an 8th element when
        ``return_encoders`` is True. ``y_*`` are sksurv structured arrays with
        fields "vit_status" (bool event) and "survival_time".
    """
    # A list is required: X[("a", "b")] would be read as a single column key.
    cat_cols = list(cat_cols)

    # Predictors: every column except the two target columns.
    X = df.drop(columns=[time_col, event_col])
    if cat_cols:
        # Cast all categorical columns once (previously repeated once per column).
        X[cat_cols] = X[cat_cols].astype("category")

    # Label-encode each categorical column; keep the encoders for inverse transforms.
    label_encoders = {}
    for col in cat_cols:
        encoder = LabelEncoder()
        X[col] = encoder.fit_transform(X[col])
        label_encoders[col] = encoder

    # TabNet embedding inputs: column positions and number of distinct codes per column.
    cat_idxs = [X.columns.get_loc(col) for col in cat_cols]
    cat_dims = [X[col].nunique() for col in cat_cols]

    # Target as an sksurv structured array (event indicator + survival time).
    y = Surv.from_dataframe(
        "vit_status",
        "survival_time",
        pd.DataFrame({"vit_status": df[event_col].astype(bool),
                      "survival_time": df[time_col]}),
    )
    np.random.seed(random_state)

    # Train/test split, then a shuffled KFold used for CV on the training part.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=random_state)
    kfold = KFold(n_splits=n_splits, shuffle=True, random_state=random_state)

    outputs = (X_train, X_test, y_train, y_test, kfold, cat_idxs, cat_dims)
    if return_encoders:
        return outputs + (label_encoders,)
    return outputs

# 함수 과정 : X & y 생성, TabNet 학습 위한 Label Encoding, train/test split, fold 정의, tabnet 학습 위한 cat_idxs, cat_dims 정의

In [None]:
def optimize_tabnet(X_train, y_train, kfold, cat_idxs, cat_dims, random_state=1234, n_trials=50):
    """
    Search TabNet hyperparameters with Optuna, maximizing the cross-validated C-index.

    Each trial samples n_d (also passed as n_a), n_steps, gamma, lr,
    weight_decay and mask_type from the "tabnet" section of ./config.yaml.
    It trains one TabNetSurvivalRegressor per KFold training split and scores
    the concordance index on the matching *validation* split. The trial value
    is the mean validation C-index over folds.

    Only X_train / y_train are used, so the held-out test set is not involved
    in model selection. Previously, every fold was scored on a global
    X_test / y_test, which leaked the test set into tuning.

    Args:
        X_train: training predictors (DataFrame with label-encoded categoricals).
        y_train: training target, structured array with fields "vit_status"
            and "survival_time".
        kfold: splitter exposing ``.split(X, y)`` (e.g. sklearn KFold).
        cat_idxs: column positions of the categorical predictors.
        cat_dims: number of categories for each entry of ``cat_idxs``.
        random_state: seed for numpy's global RNG, TabNet and the TPE sampler.
        n_trials: number of Optuna trials.

    Returns:
        best_params: the best trial's sampled parameters as a flat dict
        (n_d, n_steps, gamma, lr, weight_decay, mask_type).
    """
    np.random.seed(random_state)
    model_name = "tabnet"

    config = yaml.safe_load(Path("./config.yaml").read_text())
    space = config["tabnet"]

    # ------------------------------
    # Optuna objective (TabNet)
    # ------------------------------
    def objective_tabnet(trial: optuna.Trial) -> float:
        # Suggestion order is kept fixed so the seeded TPE sampler stays reproducible.
        params = {
            "n_d": trial.suggest_int("n_d", space["n_d"]["min"], space["n_d"]["max"]),
            "n_steps": trial.suggest_int("n_steps", space["n_steps"]["min"], space["n_steps"]["max"]),
            "gamma": trial.suggest_categorical("gamma", space["gamma"]),
            "optimizer_params": {
                "lr": trial.suggest_categorical("lr", space["lr"]),
                "weight_decay": trial.suggest_categorical("weight_decay", space["weight_decay"]),
            },
            "mask_type": trial.suggest_categorical("mask_type", space["mask_type"]),
        }

        fold_scores = []
        for train_fold, val_fold in kfold.split(X_train, y_train):
            X_fit = X_train.iloc[train_fold]
            X_val = X_train.iloc[val_fold]
            y_val = y_train[val_fold]
            # TabNet target matrix: column 0 = event indicator, column 1 = survival time.
            y_fit = np.stack(
                (y_train["vit_status"][train_fold], y_train["survival_time"][train_fold]),
                axis=-1,
            ).astype(np.float32)

            tabnet = TabNetSurvivalRegressor(seed=random_state, device_name=config["device"],
                                             n_a=params["n_d"], cat_idxs=cat_idxs,
                                             cat_dims=cat_dims, **params)
            tabnet.fit(
                X_fit.values,
                y_fit,
                loss_fn=PartialLogLikelihood,
                max_epochs=50,
                batch_size=min(16, len(X_fit)),
            )

            with torch.no_grad():
                risk_scores = np.squeeze(tabnet.predict(X_val.values))
            # Score on this fold's validation split. No random jitter is added:
            # sksurv's concordance_index_censored handles tied estimates itself,
            # and jitter made the score nondeterministic.
            fold_scores.append(
                concordance_index_censored(y_val["vit_status"], y_val["survival_time"], risk_scores)[0]
            )
        return float(np.mean(fold_scores))

    # Run the Optuna study
    study = optuna.create_study(study_name=model_name + str(datetime.datetime.now()),
                                direction="maximize",
                                sampler=optuna.samplers.TPESampler(seed=random_state))
    study.optimize(objective_tabnet, n_trials=n_trials)
    best_params = study.best_trial.params
    print("RESULTS")
    print(f"Best params: {best_params}")
    print(f"Best value: {study.best_value}")

    return best_params

In [None]:
def _tabnet_kwargs_from_params(best_params):
    """Build TabNetSurvivalRegressor kwargs from Optuna's flat params without modifying the input dict."""
    model_kwargs = {k: v for k, v in best_params.items() if k not in ("lr", "weight_decay")}
    # Optuna stores lr / weight_decay flat; the model takes them nested in optimizer_params.
    model_kwargs["optimizer_params"] = {"lr": best_params["lr"], "weight_decay": best_params["weight_decay"]}
    return model_kwargs


def _save_permutation_importance(model, X_val, y_val, fold, random_state):
    """Compute permutation importances on one validation fold and write them to a CSV file."""
    # NOTE(review): X_val is passed as a DataFrame here, while fit/predict elsewhere get .values;
    # this relies on the model's score()/predict accepting a DataFrame — confirm.
    result = permutation_importance(model, X_val, y_val, n_repeats=15, random_state=random_state)
    importances = pd.DataFrame(
        {k: result[k] for k in ("importances_mean", "importances_std")},
        index=X_val.columns,
    ).sort_values(by="importances_mean", ascending=False)

    path = f"TabNet_permutation_importances_fold_{fold}.csv"
    importances.to_csv(path, encoding='utf-8')
    print(f"Permutation importances saved to {path}")


def _save_shap_beeswarm(model, X_val, fold):
    """Explain model.predict with SHAP on one validation fold and save a beeswarm plot as PNG."""
    explainer = shap.Explainer(
        model.predict,
        X_val.values,
        feature_names=X_val.columns.tolist(),  # keep original feature names on the plot
    )
    shap_values = explainer(X_val.values)

    path = f"TabNet_beeswarm_fold_{fold}.png"
    fig = plt.figure(figsize=(10, 6))
    try:
        shap.plots.beeswarm(shap_values, show=False)
        plt.tight_layout()
        plt.savefig(path, dpi=300, bbox_inches="tight")  # bbox_inches avoids clipped labels
    finally:
        plt.close(fig)  # always release the figure, even if plotting fails
    print(f"Beeswarm plot saved to {path}")


def train_and_evaluate_tabnet(X_train, X_test, y_train, y_test, best_params, kfold, cat_idxs, cat_dims, random_state=1234):
    """
    Cross-validate a TabNet survival model with fixed hyperparameters and save explanations.

    For each KFold split of the training data this function:
    - fits TabNetSurvivalRegressor with ``best_params``;
    - scores it on the validation fold via evaluate_survival_model;
    - writes permutation importances to "TabNet_permutation_importances_fold_{i}.csv";
    - writes a SHAP beeswarm plot to "TabNet_beeswarm_fold_{i}.png".
    Both files go to the current working directory.

    Args:
        X_train, y_train: training predictors (DataFrame) and structured target
            with fields "vit_status" and "survival_time".
        X_test, y_test: held-out test data; not used by this function (kept for
            interface compatibility).
        best_params: flat Optuna params (must contain n_d, lr, weight_decay).
            The dict is not modified.
        kfold: splitter exposing ``.split(X, y)``.
        cat_idxs, cat_dims: categorical column positions and cardinalities for TabNet.
        random_state: seed for numpy's global RNG, TabNet and permutation importance.

    Returns:
        final_scores: pandas Series holding the NaN-skipping mean over folds of
        the per-fold scores. This presumably assumes evaluate_survival_model
        returns a mapping of metric -> value; TODO confirm.
    """
    np.random.seed(random_state)
    config = yaml.safe_load(Path("./config.yaml").read_text())
    model_kwargs = _tabnet_kwargs_from_params(best_params)

    # Per-fold results, keyed "fold_<i>".
    fold_scores = {}

    for fold, (train_idx, val_idx) in enumerate(kfold.split(X_train, y_train)):
        X_fit = X_train.iloc[train_idx]
        X_val = X_train.iloc[val_idx]
        # TabNet target matrix: column 0 = event indicator, column 1 = survival time.
        y_fit = np.stack(
            (y_train["vit_status"][train_idx], y_train["survival_time"][train_idx]),
            axis=-1,
        ).astype(np.float32)

        model = TabNetSurvivalRegressor(cat_idxs=cat_idxs, cat_dims=cat_dims, seed=random_state,
                                        device_name=config["device"],
                                        n_a=best_params["n_d"],
                                        **model_kwargs)
        model.fit(
            X_fit.values, y_fit,
            loss_fn=PartialLogLikelihood,
            batch_size=min(16, len(X_fit)),
            max_epochs=50,
        )

        scores = evaluate_survival_model(model, X_val.values, y_train[train_idx], y_train[val_idx])
        print(f"Final TabNet Scores in Fold {fold}: {scores}")
        fold_scores[f"fold_{fold}"] = scores

        _save_permutation_importance(model, X_val, y_train[val_idx], fold, random_state)
        _save_shap_beeswarm(model, X_val, fold)

    # Average each score over folds (NaNs skipped).
    final_scores = pd.DataFrame(fold_scores).T.mean(skipna=True)
    print(final_scores)

    return final_scores

In [None]:
############# Preprocessing ################

# processed_survival_data_modified.csv: dataset in which Age has already been binned into groups.
df = pd.read_csv("processed_survival_data_modified.csv")

# Translate the Korean implant-length label to English ("길이 10 이상" -> "Length ≥ 10").
# na_action="ignore" keeps missing lengths as NaN so the dropna below removes those rows;
# previously NaN compared unequal and was silently relabelled "Length < 10".
df["implant_length_group"] = df["implant_length_group"].map(
    lambda x: "Length ≥ 10" if x == "길이 10 이상" else "Length < 10",
    na_action="ignore",
)
# Outcome as 0/1: "fail" -> 1 (event), "survive" -> 0. Any other label becomes NaN
# and that row is removed by the dropna below.
df["survival"] = df["survival"].map({"survive": 0, "fail": 1})

# Columns excluded from the analysis.
exclude_columns = ["patient_ID", "me", "failure_reason", "failure_date",
                   "last_fu_date", "surgery_Date", "fu_for_fail_yr", "fu_for_survival_yr"]
all_columns = [col for col in df.columns if col not in exclude_columns]  # columns used in the analysis
# Drop rows with a missing value in any analysis column, then keep only those columns.
df = df.dropna(subset=all_columns)
df = df[all_columns]
selected_features = [col for col in df.columns if col not in ['fu_total_yr', 'survival'] + exclude_columns]
time_col = "fu_total_yr"
event_col = "survival"
cat_cols = ['Sex',
 'Age',
 'type_of_disability_Group2',
 'compliance_with_SPT',
 'Systemic_disease',
 'bone_augmentation_procedure',
 'tooth_loss_reason',
 'implant_diameter_group',
 'implant_length_group',
 'implant_site',
 'jaw',
 'prosthesis_type',
 'periodontal_diagnosis_group']

In [None]:
########## Final run #############
# Optuna search budget for this run (optimize_tabnet's own default is 50 trials).
N_OPTUNA_TRIALS = 3

# Reuse the column names defined in the preprocessing cell instead of repeating literals.
X_train, X_test, y_train, y_test, kfold, cat_idxs, cat_dims = prepare_tabnet_data(
    df, time_col, event_col, cat_cols
)
best_params = optimize_tabnet(
    X_train, y_train, kfold, cat_idxs, cat_dims, n_trials=N_OPTUNA_TRIALS
)
# Mean of the per-fold scores returned by train_and_evaluate_tabnet.
fold_scores_df = train_and_evaluate_tabnet(
    X_train, X_test, y_train, y_test, best_params, kfold, cat_idxs, cat_dims
)