In [34]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
import matplotlib.pyplot as plt
from lifelines import KaplanMeierFitter
import torchtuples as tt
from pycox.models import CoxPH
from pycox.evaluation import EvalSurv
from pycox.models.loss import CoxPHLoss

In [35]:
print(torch.__version__)

2.5.1+cpu


## Deepsurv 모델 아키텍쳐 생성

In [21]:
# DeepSurv 모델 정의
class DeepSurv(nn.Module):
    def __init__(self, input_dim):
        super(DeepSurv, self).__init__()
        self.fc1 = nn.Linear(input_dim, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)
        self.relu = nn.ReLU()
    
    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        return self.fc3(x)

## DataLoader 생성 (SurvivalDataset)

In [22]:
# Dataset 정의
class SurvivalDataset(Dataset):
    def __init__(self, X, y):
        self.X = torch.tensor(X, dtype=torch.float32)
        self.durations = torch.tensor(y[:, 0], dtype=torch.float32)
        self.events = torch.tensor(y[:, 1], dtype=torch.float32)
    
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.durations[idx], self.events[idx]

## colon(대장암) 데이터 불러오기

In [23]:
# 데이터 준비
path = 'C:/Users/user/Desktop/dummy_gu.csv'
data = pd.read_csv(path, index_col=0)

In [24]:
data.head()

Unnamed: 0,SERIAL_ID,sex,age,icd_10,seercode,event_inc,stime,tx_1,tx_2,tx_3,...,gu_7,gu_8,gu_9,gu_10,gu_11,gu_12,gu_13,gu_14,gu_15,gu_16
0,996,2,66,C187,4.0,1,64,1,1,0,...,0,0,0,0,0,0,0,0,0,0
1,1441,2,46,C180,1.0,0,298,1,1,0,...,0,0,0,0,0,0,0,0,0,0
2,1658,1,64,C184,1.0,1,241,1,1,0,...,0,0,0,0,0,0,0,0,0,0
3,2428,1,47,C20,2.0,1,104,1,1,0,...,0,0,0,0,0,0,0,0,0,0
4,3078,2,57,C20,1.0,0,296,1,0,0,...,0,0,0,0,0,0,0,0,0,0


## 독립변수 & 반응변수 생성
**독립변수**
* age, sex, seer_TF(전이 유무), tf_1(수술), tf_2(화학요법), tf_3(방사선요법)

**반응변수**
* survival_time, event_indicator

In [25]:
X = data[['age', 'sex', 'seer_TF', 'tx_1','tx_2','tx_3','gu_1'
          ,'gu_2','gu_3','gu_4','gu_5','gu_6'
         ,'gu_7','gu_8','gu_9','gu_10','gu_11'
         ,'gu_12','gu_13','gu_14','gu_15','gu_16']].values
y = data[['stime', 'event_inc']].values

In [26]:
print(X)
print(y)

[[66  2  1 ...  0  0  0]
 [46  2  0 ...  0  0  0]
 [64  1  0 ...  0  0  0]
 ...
 [52  1  0 ...  0  0  0]
 [56  1  1 ...  0  0  0]
 [56  1  1 ...  0  0  0]]
[[ 64   1]
 [298   0]
 [241   1]
 ...
 [ 25   0]
 [ 25   0]
 [ 25   0]]


## scaling - StrandardScaler()
* 평균 0 , 표준편차 1 로 표준화

In [27]:
scaler = StandardScaler()
X = scaler.fit_transform(X)
print(X)

[[ 0.19977202  1.21084236  0.78690251 ... -0.12624407 -0.2386643
  -0.17822556]
 [-1.4208223   1.21084236 -1.27080546 ... -0.12624407 -0.2386643
  -0.17822556]
 [ 0.03771259 -0.82587134 -1.27080546 ... -0.12624407 -0.2386643
  -0.17822556]
 ...
 [-0.93464401 -0.82587134 -1.27080546 ... -0.12624407 -0.2386643
  -0.17822556]
 [-0.61052514 -0.82587134  0.78690251 ... -0.12624407 -0.2386643
  -0.17822556]
 [-0.61052514 -0.82587134  0.78690251 ... -0.12624407 -0.2386643
  -0.17822556]]


## train, test split

In [28]:
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

In [29]:
print(X_train)
print(X_test)
print(y_train)
print(y_test)

[[ 1.01006918 -0.82587134 -1.27080546 ... -0.12624407 -0.2386643
  -0.17822556]
 [ 0.44286117  1.21084236  0.78690251 ... -0.12624407 -0.2386643
  -0.17822556]
 [ 0.1187423   1.21084236 -1.27080546 ... -0.12624407 -0.2386643
  -0.17822556]
 ...
 [-0.28640628  1.21084236 -1.27080546 ... -0.12624407 -0.2386643
  -0.17822556]
 [-0.36743599 -0.82587134  0.78690251 ... -0.12624407 -0.2386643
  -0.17822556]
 [ 0.76698003 -0.82587134 -1.27080546 ... -0.12624407 -0.2386643
  -0.17822556]]
[[ 0.76698003  1.21084236  0.78690251 ... -0.12624407 -0.2386643
  -0.17822556]
 [ 1.0910989  -0.82587134  0.78690251 ... -0.12624407 -0.2386643
  -0.17822556]
 [ 0.68595032 -0.82587134  0.78690251 ... -0.12624407 -0.2386643
  -0.17822556]
 ...
 [-0.85361429  1.21084236 -1.27080546 ... -0.12624407 -0.2386643
  -0.17822556]
 [-0.52949542  1.21084236 -1.27080546 ... -0.12624407 -0.2386643
  -0.17822556]
 [ 0.68595032  1.21084236  0.78690251 ... -0.12624407  4.18998572
  -0.17822556]]
[[ 69   0]
 [132   0]
 [104

In [30]:
# 데이터 로더
train_dataset = SurvivalDataset(X_train, y_train)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

In [31]:
# 모델 학습
# input dimension 은 독립변수들의 컬럼의 차원임.

model = DeepSurv(input_dim=X.shape[1])
criterion = CoxPHLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

In [32]:
epochs = 100

for epoch in range(epochs):
    model.train()
    for X_batch, durations, events in train_loader:
        optimizer.zero_grad()
        predictions = model(X_batch).squeeze()
        loss = criterion(predictions, durations, events)  # CoxPHLoss로 교체 가능
        loss.backward()
        optimizer.step()
    print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item()}")

Epoch 1/100, Loss: nan
Epoch 2/100, Loss: nan
Epoch 3/100, Loss: nan
Epoch 4/100, Loss: nan
Epoch 5/100, Loss: nan
Epoch 6/100, Loss: nan
Epoch 7/100, Loss: nan
Epoch 8/100, Loss: nan
Epoch 9/100, Loss: nan
Epoch 10/100, Loss: nan
Epoch 11/100, Loss: nan
Epoch 12/100, Loss: nan
Epoch 13/100, Loss: nan
Epoch 14/100, Loss: nan
Epoch 15/100, Loss: nan
Epoch 16/100, Loss: nan
Epoch 17/100, Loss: nan
Epoch 18/100, Loss: nan
Epoch 19/100, Loss: nan
Epoch 20/100, Loss: nan
Epoch 21/100, Loss: nan
Epoch 22/100, Loss: nan
Epoch 23/100, Loss: nan
Epoch 24/100, Loss: nan
Epoch 25/100, Loss: nan
Epoch 26/100, Loss: nan
Epoch 27/100, Loss: nan
Epoch 28/100, Loss: nan
Epoch 29/100, Loss: nan
Epoch 30/100, Loss: nan
Epoch 31/100, Loss: nan
Epoch 32/100, Loss: nan
Epoch 33/100, Loss: nan
Epoch 34/100, Loss: nan
Epoch 35/100, Loss: nan
Epoch 36/100, Loss: nan
Epoch 37/100, Loss: nan
Epoch 38/100, Loss: nan
Epoch 39/100, Loss: nan
Epoch 40/100, Loss: nan
Epoch 41/100, Loss: nan
Epoch 42/100, Loss: nan
E

In [None]:
# Kaplan-Meier 곡선 시각화
kmf = KaplanMeierFitter()

# 학습 데이터 생존 곡선
durations = y_train[:, 0]
events = y_train[:, 1]
kmf.fit(durations, event_observed=events)
kmf.plot_survival_function()
plt.title("Kaplan-Meier Survival Curve (Training Data)")
plt.show()

### DeepSurv 모델의 Cox Loss

DeepSurv는 딥러닝을 활용한 생존 분석(survival analysis) 모델로, 시간에 따른 생존 확률을 예측하거나 위험(hazard)을 모델링하는 데 사용됩니다. 이 모델은 학습 과정에서 **Cox Loss**를 사용하며, 이는 **Cox Proportional Hazards Model**의 수학적 기반을 딥러닝에 맞게 변형한 손실 함수입니다.

---

#### Cox Proportional Hazards Model
Cox 비례위험 모델은 생존 데이터의 위험 비율(hazard ratio)을 분석하는 데 사용됩니다.

- 위험 함수(hazard function):
  \[
  h(t | x) = h_0(t) \exp(f(x))
  \]

  **여기서**:
  - \( h(t | x) \): 특정 시간 \( t \)에 대한 조건부 위험 함수
  - \( h_0(t) \): 기준 위험 함수(baseline hazard function)
  - \( f(x) \): 입력 \( x \)에 대한 선형 예측자 (DeepSurv에서는 신경망의 출력값)

---

#### Cox Loss의 정의
Cox Loss는 생존 분석의 **우도 함수(likelihood function)**를 기반으로 합니다.

1. **Partial Likelihood**:
   \[
   L = \prod_{i \in D} \frac{\exp(f(x_i))}{\sum_{j \in R_i} \exp(f(x_j))}
   \]

   **여기서**:
   - \( D \): 사건(event)이 발생한 샘플의 집합 (예: 사망)
   - \( R_i \): \( i \)번째 사건 이후 관찰 중인 샘플(리스크 세트)
   - \( f(x) \): 모델의 예측값 (예: 위험 점수)

2. **Log Partial Likelihood**:
   로그 변환 후:
   \[
   \ell = \sum_{i \in D} \left( f(x_i) - \log \sum_{j \in R_i} \exp(f(x_j)) \right)
   \]

3. **Negative Log-Likelihood (Cox Loss)**:
   최소화 문제로 변환하기 위해 음수 부호를 붙이면:
   \[
   \text{Cox Loss} = -\ell = - \sum_{i \in D} \left( f(x_i) - \log \sum_{j \in R_i} \exp(f(x_j)) \right)
   \]

---

#### Cox Loss의 의미
- **순위 기반 학습**: 데이터의 사건 순서(order)를 보존.
- **검열 데이터 처리**: 사건이 발생하지 않은 샘플도 학습에 포함.
- **확률적 해석**: 위험도 \( \exp(f(x)) \)를 활용한 확률적 모델링.

---

#### PyTorch 기반 구현 예시

```python
import torch
import torch.nn as nn

class CoxLoss(nn.Module):
    def __init__(self):
        super(CoxLoss, self).__init__()
    
    def forward(self, hazards, events):
        """
        Cox Loss 계산
        :param hazards: 모델의 예측 위험 점수 (torch.Tensor)
        :param events: 사건 발생 여부 (torch.Tensor, 0 또는 1)
        """
        # 위험 점수의 정렬 (시간 순서에 따라)
        sorted_indices = torch.argsort(-hazards)
        hazards = hazards[sorted_indices]
        events = events[sorted_indices]
        
        # 누적 합 계산
        cumulative_hazard = torch.cumsum(torch.exp(hazards), dim=0)
        
        # Cox Loss 계산
        log_risk = torch.log(cumulative_hazard)
        loss = -torch.sum(hazards * events - log_risk * events)
        
        return loss
