In [403]:
import pandas as pd
import numpy as np
import torch.optim as optim
from torch.nn import init
import torch.nn.functional as F
from torch.nn.parameter import Parameter
import math
import sklearn.metrics as metrics

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, WeightedRandomSampler
from sklearn.preprocessing import StandardScaler, MinMaxScaler
from lifelines import CoxPHFitter
from sklearn.model_selection import train_test_split
from lifelines import KaplanMeierFitter
from lifelines.utils import concordance_index
import matplotlib.pyplot as plt
from tqdm import tqdm

In [None]:
cols_x = ["trt","age","sex","ascites","hepato",
          "spiders","edema","bili","chol","albumin","copper","alk.phos","ast","trig","platelet","protime","stage"]
 
col_target = "time"

cols_xy = ["trt","age","sex","ascites","hepato","time",
          "spiders","edema","bili","chol","albumin","copper","alk.phos","ast","trig","platelet","protime","stage"]
 

col_all = ["time","status","trt","age","sex","ascites","hepato",
          "spiders","edema","bili","chol","albumin","copper","alk.phos","ast","trig","platelet","protime","stage"]

In [None]:
def auto_encode_features(df, skip_columns=None, max_unique_for_label=30):
    df = df.copy()
    skip_columns = skip_columns or []

    for col in df.columns:
        if col in skip_columns:
            print(f"跳过列: {col}")
            continue

        dtype = df[col].dtype
        if dtype == 'bool':
            df[col] = df[col].astype(int)
        elif dtype == 'object' or isinstance(df[col].iloc[0], str):
            num_unique = df[col].nunique()
            if 1 < num_unique <= max_unique_for_label:
                df[col] = df[col].astype('category').cat.codes
            elif num_unique > max_unique_for_label:
                dummies = pd.get_dummies(df[col], prefix=col)
                df = pd.concat([df, dummies], axis=1)
                df.drop(columns=[col], inplace=True)

    return df

In [None]:
df = pd.read_csv("data/Mayo.csv")
df = df[col_all]
df =df.dropna()
df = auto_encode_features(df)
scaler = StandardScaler()
df[cols_xy] = scaler.fit_transform(df[cols_xy])
x = df[cols_x].to_numpy()
y = df[col_target].to_numpy()
e = df["status"].to_numpy()
x_np = x
y_np = y
e_np = e

dfy = df[col_target][df["status"]==0]
y_sen_np = dfy.to_numpy()

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X_tensor = torch.FloatTensor(x).to(device)
Y_tensor = torch.FloatTensor(y).to(device).unsqueeze(1)
E_tensor = torch.FloatTensor(e).to(device).unsqueeze(1)

In [None]:
class SurvivalDataset(Dataset):
    def __init__(self, X, Y, E):
        self.X = X
        self.Y = Y
        self.E = E
    def __len__(self):
        return len(self.X)
    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx], self.E[idx]

dataset = SurvivalDataset(X_tensor, Y_tensor, E_tensor)
dataloader = DataLoader(dataset, batch_size=64, shuffle=True)

In [None]:
class DiffusionModel(nn.Module):
    def __init__(self, input_dim, hidden_dim=128):
        super().__init__()
        self.time_embed = nn.Sequential(
            nn.Linear(1, hidden_dim),
            nn.SiLU(),
            nn.Linear(hidden_dim, hidden_dim)
        )
        self.net = nn.Sequential(
            nn.Linear(input_dim + hidden_dim, hidden_dim),
            nn.SiLU(),
            nn.BatchNorm1d(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.SiLU()
        )
        self.noise_head = nn.Linear(hidden_dim, input_dim)  
        self.time_head = nn.Linear(hidden_dim, 1)          
    def forward(self, x, t):
        t_embed = self.time_embed(t.float())
        features = self.net(torch.cat([x, t_embed], dim=-1))
        noise_pred = self.noise_head(features)
        time_pred = self.time_head(features) 
        
        return noise_pred, time_pred.squeeze(-1)
    
    
class CensoringModel:
    def __init__(self):
        self.model = None       
    def fit(self, X, Y, E):
        data = np.concatenate([X, Y, E], axis=1)
        columns = [f'x{i}' for i in range(X.shape[1])] + ['time', 'event']
        df = pd.DataFrame(data, columns=columns)
        self.model = CoxPHFitter(penalizer=0.1)
        self.model.fit(df, duration_col='time', event_col='event')
    def predict_censoring_prob(self, X, Y):
        if self.model is None:
            raise ValueError("Model not fitted yet")
        df = pd.DataFrame(np.concatenate([X, Y], axis=1), 
                         columns=[f'x{i}' for i in range(X.shape[1])] + ['time'])
        return self.model.predict_survival_function(df).values[-1, :]

In [None]:
censoring_model = CensoringModel()
censoring_model.fit(x, y.reshape(-1,1), e.reshape(-1,1))
num_steps = 1000
beta_start = 0.0001
beta_end = 0.02
betas = torch.linspace(beta_start, beta_end, num_steps).to(device)
alphas = 1 - betas
alphas_cumprod = torch.cumprod(alphas, dim=0)
sqrt_alphas_cumprod = torch.sqrt(alphas_cumprod)
sqrt_one_minus_alphas_cumprod = torch.sqrt(1 - alphas_cumprod)

In [None]:
def _pl_loss(time_pred, e, censoring_model, batch_features):
    with torch.no_grad():
        time_np = time_pred.detach().cpu().numpy().reshape(-1, 1)
        features_np = batch_features.detach().cpu().numpy()
        survival_probs = censoring_model.predict_censoring_prob(features_np, time_np)
        risk_scores = -np.log(survival_probs + 1e-8)  
        risk_scores = torch.from_numpy(risk_scores).float().to(time_pred.device)
    valid_pairs = (risk_scores.unsqueeze(1) > risk_scores.unsqueeze(0)).float()
    pl_loss = -torch.mean(valid_pairs * e.float())
    
    return pl_loss

def diffusion_loss(model, x0, t, e, censoring_model):
    noise = torch.randn_like(x0)
    sqrt_alpha_cumprod_t = sqrt_alphas_cumprod[t].view(-1, 1)
    sqrt_one_minus_alpha_cumprod_t = sqrt_one_minus_alphas_cumprod[t].view(-1, 1)
    xt = sqrt_alpha_cumprod_t * x0 + sqrt_one_minus_alpha_cumprod_t * noise
    predicted_noise, time_pred = model(xt, t.float().unsqueeze(1))
    noise_loss = nn.functional.mse_loss(predicted_noise, noise)
    wasserstein_loss = torch.mean(torch.abs(x0 - predicted_noise))
    time_loss = torch.tensor(0.0, device=xt.device)
    if torch.sum(e) > 0:
        uncensored_mask = (e == 1).squeeze()  
        t_pred_uncensored = time_pred[uncensored_mask] 
        t_true_uncensored = x0[uncensored_mask, -1]   
        time_loss += F.mse_loss(t_pred_uncensored, t_true_uncensored)
    if torch.sum(1-e) > 0:
        censored_mask = (e == 0).squeeze()  
        t_pred_censored = time_pred[censored_mask]
        t_censored = x0[censored_mask, -1]  
        time_diff = t_censored - t_pred_censored
        time_loss += torch.mean(F.relu(time_diff) ** 2)  

    batch_features = x0[:, :-1] 
    partial_likelihood_loss = _pl_loss(time_pred, e, censoring_model, batch_features)

    total_loss = (noise_loss + 
                 0.2 * wasserstein_loss + 
                 0.2 * time_loss + 
                 0.1 * partial_likelihood_loss)
    
    return total_loss, {
        'noise_loss': noise_loss.item(),
        'wasserstein_loss': wasserstein_loss.item(),
        'time_loss': time_loss.item(),
        'partial_likelihood_loss': partial_likelihood_loss.item() if censoring_model else 0.0
    }

In [None]:
model = DiffusionModel(input_dim=x.shape[1] + 1).to(device)  
optimizer = optim.Adam(model.parameters(), lr=1e-3)

num_epochs = 300
loss_history = {'total': [], 'mse': [], 'wasserstein': [], 'partial_likelihood': []}

for epoch in range(num_epochs):
    epoch_loss = 0.0
    epoch_mse = 0.0
    epoch_wasserstein = 0.0
    epoch_pl = 0.0
    
    for x, y, e in dataloader:
        optimizer.zero_grad()
        x0 = torch.cat([x, y], dim=1)
        t = torch.randint(0, num_steps, (x.size(0),), device=device)
        loss, loss_components = diffusion_loss(model, x0, t, e, censoring_model)
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        epoch_mse += loss_components['noise_loss']
        epoch_wasserstein += loss_components['wasserstein_loss']
        epoch_pl += loss_components['partial_likelihood_loss']
    num_batches = len(dataloader)
    print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss/num_batches:.4f}")

Epoch 1/300, Loss: 1.5302
Epoch 2/300, Loss: 1.5002
Epoch 3/300, Loss: 1.4610
Epoch 4/300, Loss: 1.4100
Epoch 5/300, Loss: 1.4975
Epoch 6/300, Loss: 1.4439
Epoch 7/300, Loss: 1.4030
Epoch 8/300, Loss: 1.4518
Epoch 9/300, Loss: 1.4531
Epoch 10/300, Loss: 1.4311
Epoch 11/300, Loss: 1.4659
Epoch 12/300, Loss: 1.4212
Epoch 13/300, Loss: 1.4794
Epoch 14/300, Loss: 1.4202
Epoch 15/300, Loss: 1.4508
Epoch 16/300, Loss: 1.4326
Epoch 17/300, Loss: 1.4084
Epoch 18/300, Loss: 1.3802
Epoch 19/300, Loss: 1.4546
Epoch 20/300, Loss: 1.4605
Epoch 21/300, Loss: 1.3951
Epoch 22/300, Loss: 1.4005
Epoch 23/300, Loss: 1.4515
Epoch 24/300, Loss: 1.3959
Epoch 25/300, Loss: 1.4198
Epoch 26/300, Loss: 1.4074
Epoch 27/300, Loss: 1.4166
Epoch 28/300, Loss: 1.3998
Epoch 29/300, Loss: 1.4062
Epoch 30/300, Loss: 1.4268
Epoch 31/300, Loss: 1.4651
Epoch 32/300, Loss: 1.4387
Epoch 33/300, Loss: 1.4602
Epoch 34/300, Loss: 1.4282
Epoch 35/300, Loss: 1.4358
Epoch 36/300, Loss: 1.3697
Epoch 37/300, Loss: 1.4083
Epoch 38/3

In [None]:
def generate_samples(model, num_samples, censoring_model, device, censoring_rate=0.3):
    x = torch.randn(num_samples, n_features + 1, device=device)
    for t in tqdm(range(num_steps-1, -1, -1), desc='Sampling'):
        t_tensor = torch.full((num_samples,), t, device=device)
        with torch.no_grad():
            pred_noise, _ = model(x, t_tensor.float().unsqueeze(1))
        alpha_t = torch.tensor(alphas[t], device=device).view(-1, 1)
        alpha_cumprod_t = torch.tensor(alphas_cumprod[t], device=device).view(-1, 1)
        beta_t = torch.tensor(betas[t], device=device).view(-1, 1)
        mean = (1 / torch.sqrt(alpha_t)) * (x - ((1 - alpha_t) / torch.sqrt(1 - alpha_cumprod_t)) * pred_noise)
        if t > 0:
            noise = torch.randn_like(x)
            x = mean + torch.sqrt(beta_t) * noise
        else:
            x = mean
    generated_features = x[:, :-1].cpu().numpy()
    generated_times = x[:, -1].cpu().numpy()
    
    generated_times = generated_times * y_np.std() + y_np.mean()
    generated_times = np.maximum(generated_times, 0)

    censored_times = np.random.exponential(scale=1.0/(1-censoring_rate), size=num_samples)
    event_indicator = (generated_times <= censored_times).astype(int)
    observed_times = np.where(event_indicator == 1, generated_times, censored_times)
    
    return generated_features, observed_times, event_indicator

In [None]:
n_features = x.shape[1]
generated_features, generated_times, generated_events = generate_samples(
    model, num_samples=500, censoring_model=censoring_model, device=device
)

  alpha_t = torch.tensor(alphas[t], device=device).view(-1, 1)
  alpha_cumprod_t = torch.tensor(alphas_cumprod[t], device=device).view(-1, 1)
  beta_t = torch.tensor(betas[t], device=device).view(-1, 1)
Sampling: 100%|██████████| 1000/1000 [00:00<00:00, 1247.96it/s]


In [None]:
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(42)
np.random.seed(42)
if device.type == 'cuda':
    torch.cuda.manual_seed_all(42)
original_X = x_np  
original_Y = y_np         
original_E = e_np    
generated_X = generated_features  
generated_Y = generated_times     
generated_events = generated_events.reshape(-1)
generated_E = generated_events   
train_idx, test_idx = train_test_split(
    np.arange(len(original_X)), test_size=0.2, random_state=42
)
X_train = np.concatenate([original_X[train_idx], generated_X])
Y_train = np.concatenate([original_Y[train_idx], generated_Y])
E_train = np.concatenate([original_E[train_idx], generated_E])
X_test = original_X[test_idx]
Y_test = original_Y[test_idx]
E_test = original_E[test_idx]
sample_weights = np.concatenate([
    np.ones(len(train_idx)), 
    np.full(len(generated_X), 0.3)
])
def standardize(X_train, X_test):
    mean = X_train.mean(axis=0)
    std = X_train.std(axis=0)
    return (X_train - mean) / (std + 1e-8), (X_test - mean) / (std + 1e-8)

X_train, X_test = standardize(X_train, X_test)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
X_train_tensor = torch.FloatTensor(X_train).to(device)
Y_train_tensor = torch.FloatTensor(Y_train).to(device).unsqueeze(1)
E_train_tensor = torch.FloatTensor(E_train).to(device).unsqueeze(1)
weights_train_tensor = torch.FloatTensor(sample_weights).to(device)
X_test_tensor = torch.FloatTensor(X_test).to(device)
Y_test_tensor = torch.FloatTensor(Y_test).to(device).unsqueeze(1)
E_test_tensor = torch.FloatTensor(E_test).to(device).unsqueeze(1)


In [None]:
class SurvivalDataset(Dataset):
    def __init__(self, X, Y, E, weights=None, generated_mask=None):
        self.X = X
        self.Y = Y
        self.E = E
        self.weights = weights if weights is not None else torch.ones(len(X))
        self.generated_mask = generated_mask if generated_mask is not None else torch.zeros(len(X))
        
    def __len__(self):
        return len(self.X)
    
    def __getitem__(self, idx):
        return self.X[idx], self.Y[idx], self.E[idx], self.weights[idx], self.generated_mask[idx]

generated_mask = np.concatenate([
    np.zeros(len(train_idx)),
    np.ones(len(generated_X))
])

train_dataset = SurvivalDataset(
    X_train_tensor, Y_train_tensor, E_train_tensor,
    weights_train_tensor, torch.FloatTensor(generated_mask).to(device)
)
test_dataset = SurvivalDataset(X_test_tensor, Y_test_tensor, E_test_tensor)

In [None]:
class cycleblock(nn.Module):
    def __init__(self, in_features, out_features, drop_res):
        super(cycleblock, self).__init__()
        self.drop_lay = nn.Dropout(drop_res)
        self.linear1 = nn.Linear(in_features, out_features)
        self.gelu_lay = nn.CELU(inplace=True)
        self.linear2 = nn.Linear(out_features, out_features)
        self.batch_lay = nn.BatchNorm1d(out_features)

    def forward(self, x):
        identity = x
        out = self.drop_lay(x)
        out = self.linear1(out)
        out += identity        
        out = self.batch_lay(out)
        out = self.gelu_lay(out)
        out = self.drop_lay(x)
        out = self.linear2(out)
        out += identity 
        out = self.batch_lay(out)
        out = self.gelu_lay(out)
        return out
    

class Mainnet(nn.Module):
    def __init__(self, input_dim, hidden_dim, drop, cycle_nb):
        super(Mainnet, self).__init__()

        self.fc1 = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.BatchNorm1d(hidden_dim),
            nn.CELU(inplace=True),
        )

        self.gelu = nn.CELU(inplace=True)


        self.blocks = nn.ModuleList([cycleblock(hidden_dim, hidden_dim, drop) for i in range(cycle_nb)])

        self.fc2 = nn.Sequential(
            nn.Dropout(drop),
            nn.Linear(hidden_dim, 1),
        )

    def forward(self, x):
      risk_score = self.fc1(x)
      for layer in self.blocks:
        risk_score = layer(risk_score)
      risk_score = self.fc2(risk_score)
      return risk_score



class Regularization(object):
    def __init__(self, order, weight_decay):

        super(Regularization, self).__init__()
        self.order = order
        self.weight_decay = weight_decay

    def __call__(self, model):

        reg_loss = 0
        for name, w in model.named_parameters():
            if 'weight' in name:
                reg_loss = reg_loss + torch.norm(w, p=self.order)
        reg_loss = self.weight_decay * reg_loss
        return reg_loss
    

class EnhancedNegativeLogLikelihood(nn.Module):
    def __init__(self, L2_reg, alpha=0.1, gen_penalty=0.1):
        super(EnhancedNegativeLogLikelihood, self).__init__()
        self.L2_reg = L2_reg
        self.alpha = alpha
        self.gen_penalty = gen_penalty
        self.reg = Regularization(order=2, weight_decay=self.L2_reg)

    def forward(self, risk_pred, y, e, model, weights, gen_mask):
        risk_pred = risk_pred.float()
        y = y.float()
        e = e.float()
        
        weights = torch.ones_like(e).float() if weights is None else weights.float()
        gen_mask = torch.zeros_like(e).bool() if gen_mask is None else gen_mask.bool() 
        
        risk_pred = risk_pred.reshape(-1, 1)
        y = y.reshape(-1, 1)
        e = e.reshape(-1, 1)
        weights = weights.reshape(-1, 1)
        gen_mask = gen_mask.reshape(-1, 1)
        with torch.no_grad():
            y_diff = y.T - y  
            mask = (y_diff <= 0).float()  
        weighted_risk = torch.exp(risk_pred) * weights
        sum_risk = torch.sum(weighted_risk * mask, dim=0) 
        sum_mask = torch.sum(mask * weights, dim=0)  

        log_sum_risk = torch.log(sum_risk / (sum_mask + 1e-8))
        nll = - (risk_pred - log_sum_risk) * e
        weighted_nll = nll * torch.where(gen_mask, self.alpha, 1.)  
        penalty = self.gen_penalty * torch.mean(risk_pred[gen_mask.squeeze()] ** 2) if gen_mask.any() else 0.0
        l2_loss = self.reg(model)
        total_loss = (torch.sum(weighted_nll) / (torch.sum(e) + 1e-8)) + penalty + l2_loss        
        return total_loss


def calculate_ibs(model, data_loader, max_time=None, n_points=20):
    all_Y = []
    all_E = []
    all_risk_scores = []
    with torch.no_grad():
        for x, y, e, _, _ in data_loader:
            risk_score = model(x)
            all_Y.append(y.cpu().numpy())
            all_E.append(e.cpu().numpy())
            all_risk_scores.append(risk_score.cpu().numpy())
    
    Y = np.concatenate(all_Y).flatten()
    E = np.concatenate(all_E).flatten()
    risk_scores = np.concatenate(all_risk_scores).flatten()
    if max_time is None:
        max_time = np.max(Y)
    times = np.linspace(0, max_time, n_points)
    kmf = KaplanMeierFitter()
    kmf.fit(Y, E)
    km_survival = kmf.predict(times).values
    risk_scores = (risk_scores - np.mean(risk_scores)) / np.std(risk_scores)
    pred_survival = np.exp(-np.clip(risk_scores.reshape(-1, 1) * times.reshape(1, -1)/max_time, -50, 50))
    brier_scores = np.zeros(len(times))
    for i, t in enumerate(times):
        observed = (Y <= t) & (E == 1)
        weights = np.ones_like(Y)
        weights[(Y < t) & (E == 0)] = 0
        squared_error = (observed - (1 - pred_survival[:, i])) ** 2
        brier_scores[i] = np.mean(weights * squared_error)
    ibs = np.trapz(brier_scores, times) / (times[-1] - times[0])
    return ibs

In [None]:
sampler = WeightedRandomSampler(weights_train_tensor, len(weights_train_tensor), replacement=True)
hidden_dim = 32
drop = 0.1
cycle_nb = 3
batch_size = 64
bestc = 0
bestibs = 1

# batch_size =train_dataset.__len__()

train_loader = DataLoader(train_dataset, batch_size=batch_size, sampler=sampler)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
model = Mainnet(input_dim=X_train.shape[1], hidden_dim = hidden_dim, drop = drop, cycle_nb = cycle_nb).to(device)
optimizer = optim.Adam(model.parameters(), lr=1e-5, weight_decay=1e-6)
loss_fn = EnhancedNegativeLogLikelihood(0).to(device)
penalty_strength = 0.2 
num_epochs = 100
train_loss_history = []
test_loss_history = []
test_c_index_history = []
for epoch in range(num_epochs):
    model.train()
    epoch_loss = 0.0
    for x, y, e, w, gen_mask in train_loader:
        optimizer.zero_grad()
        risk_score = model(x)
        train_loss = loss_fn(risk_score, y, e, model, w, gen_mask)
        threshold = 200.0  
        if train_loss > threshold:
            with torch.no_grad():
                optimizer.zero_grad()  
        else:
            optimizer.zero_grad()
            train_loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()
        
        epoch_loss += train_loss.item()
    model.eval()
    test_epoch_loss = 0.0
    with torch.no_grad(): 
        risk_scores_test = []
        Y_test_list = []
        E_test_list = []
        for x, y, e, _, _ in test_loader:
            risk_score = model(x)
            risk_scores_test.append(risk_score.cpu().numpy())
            Y_test_list.append(y.cpu().numpy())
            E_test_list.append(e.cpu().numpy())
        risk_scores_test = np.concatenate(risk_scores_test).squeeze()
        Y_test_list = np.concatenate(Y_test_list).squeeze()
        E_test_list = np.concatenate(E_test_list).squeeze()
        c_index_test = concordance_index(Y_test_list, -risk_scores_test, E_test_list)
        max_time = np.max(np.concatenate([Y_train, Y_test]))
        test_ibs = calculate_ibs(model, test_loader, max_time=max_time) % 1
        if bestc < c_index_test:
            bestc = c_index_test
        if bestibs > test_ibs:
            bestibs = test_ibs
    train_loss_history.append(epoch_loss / len(train_loader))
    test_loss_history.append(epoch_loss / len(test_loader))
    test_c_index_history.append(c_index_test) 
    print(f"Epoch {epoch+1}/{num_epochs}, "
          f"Train Loss: {epoch_loss/len(train_loader):.4f}, "
          f"Test C-index: {c_index_test:.4f}")
print(f"Test C-index: {bestc}, "
      f"Test ibs: {bestibs}")

Epoch 1/100, Train Loss: 1.5871, Test C-index: 0.7123
Epoch 2/100, Train Loss: 2.2834, Test C-index: 0.7188
Epoch 3/100, Train Loss: 3.6696, Test C-index: 0.7374
Epoch 4/100, Train Loss: 3.4174, Test C-index: 0.7319
Epoch 5/100, Train Loss: 1.9773, Test C-index: 0.7298
Epoch 6/100, Train Loss: 1.5236, Test C-index: 0.7473
Epoch 7/100, Train Loss: 2.7513, Test C-index: 0.7418
Epoch 8/100, Train Loss: 0.9013, Test C-index: 0.7527
Epoch 9/100, Train Loss: 1.4693, Test C-index: 0.7407
Epoch 10/100, Train Loss: 1.6489, Test C-index: 0.7374
Epoch 11/100, Train Loss: 2.9619, Test C-index: 0.7407
Epoch 12/100, Train Loss: 2.8295, Test C-index: 0.7418
Epoch 13/100, Train Loss: 1.4958, Test C-index: 0.7462
Epoch 14/100, Train Loss: 1.9826, Test C-index: 0.7352
Epoch 15/100, Train Loss: 1.0301, Test C-index: 0.7462
Epoch 16/100, Train Loss: 0.3178, Test C-index: 0.7396
Epoch 17/100, Train Loss: 0.3599, Test C-index: 0.7298
Epoch 18/100, Train Loss: 1.0897, Test C-index: 0.7462
Epoch 19/100, Train

In [None]:

torch.save(model, 'Mayo.pth')

In [438]:
loaded_model = torch.load('Mayo.pth')

In [None]:

def calculate_ibs(model, data_loader, max_time=None, n_points=20):
    """计算Integrated Brier Score"""
    all_Y = []
    all_E = []
    all_risk_scores = []
    with torch.no_grad():
        for x, y, e, _, _ in data_loader:
            risk_score = model(x)
            all_Y.append(y.cpu().numpy())
            all_E.append(e.cpu().numpy())
            all_risk_scores.append(risk_score.cpu().numpy())
    
    Y = np.concatenate(all_Y).flatten()
    E = np.concatenate(all_E).flatten()
    risk_scores = np.concatenate(all_risk_scores).flatten()

    if max_time is None:
        max_time = np.max(Y)
    times = np.linspace(0, max_time, n_points)
    kmf = KaplanMeierFitter()
    kmf.fit(Y, E)
    km_survival = kmf.predict(times).values
    risk_scores = (risk_scores - np.mean(risk_scores)) / np.std(risk_scores)
    pred_survival = np.exp(-np.clip(risk_scores.reshape(-1, 1) * times.reshape(1, -1)/max_time, -50, 50))
    brier_scores = np.zeros(len(times))
    for i, t in enumerate(times):
        observed = (Y <= t) & (E == 1)
        weights = np.ones_like(Y)
        weights[(Y < t) & (E == 0)] = 0
        squared_error = (observed - (1 - pred_survival[:, i])) ** 2
        brier_scores[i] = np.mean(weights * squared_error)
    
    ibs = np.trapz(brier_scores, times) / (times[-1] - times[0])
    return ibs
max_time = np.max(np.concatenate([Y_train, Y_test]))
train_ibs = calculate_ibs(loaded_model, train_loader, max_time=max_time) % 1
test_ibs = calculate_ibs(loaded_model, test_loader, max_time=max_time) % 1

print(f"\nFinal Evaluation:")
print(f"Test C-index: {bestc}")
print(f"Train IBS: {train_ibs:.4f}")
print(f"Test IBS: {test_ibs:.4f}")


Final Evaluation:
Test C-index: 0.7592997811816192
Train IBS: 0.1549
Test IBS: 0.0312
