In [1]:
import os
import h5py
import numpy as np
import pandas as pd
import torch

from torch.utils.data import Dataset

class SurvivalDataset(Dataset):
    ''' Dataset that loads survival data (features, event flags, event times) from an .h5 file. '''
    def __init__(self, h5_file, is_train):
        ''' Loads one split from the .h5 file and min-max normalizes its features.

        :param h5_file: (String) the path of .h5 file
        :param is_train: (bool) which split to load
                is_train=True: loads the 'train' group
                is_train=False: loads the 'test' group

        NOTE(review): the normalization statistics are computed from the
        loaded split itself, so a test split is scaled with its own min/max
        rather than the training split's.
        '''
        # loads data
        self.X, self.e, self.y = \
            self._read_h5_file(h5_file, is_train)
        # normalizes data
        self._normalize()

        print('=> load {} samples'.format(self.X.shape[0]))

    def _read_h5_file(self, h5_file, is_train):
        ''' Reads the arrays of one split ('train' or 'test') from the .h5 file.

        :return X: (np.array) (n, m)
            m is features dimension.
        :return e: (np.array) (n, 1)
            whether the event occurs? (1: occurs; 0: others)
        :return y: (np.array) (n, 1)
            the time of event e.
        '''
        split = 'train' if is_train else 'test'
        with h5py.File(h5_file, 'r') as f:
            # [()] reads the full dataset into memory as a numpy array
            X = f[split]['x'][()]
            e = f[split]['e'][()].reshape(-1, 1)
            y = f[split]['t'][()].reshape(-1, 1)
        return X, e, y

    def _normalize(self):
        ''' Min-max scales every feature column of self.X (shape (n, m)) to [0, 1].

        A column whose values are all equal has max == min; dividing by that
        zero range would produce NaN for the whole column, so such a range is
        replaced by 1 and the column is mapped to 0 instead.
        '''
        col_min = self.X.min(axis=0)
        col_range = self.X.max(axis=0) - col_min
        # guard against 0/0 on constant columns (col_range is a fresh array,
        # so the in-place write does not touch self.X)
        col_range[col_range == 0] = 1
        self.X = (self.X - col_min) / col_range

    def __getitem__(self, item):
        ''' Returns sample `item` as torch tensors.

        :return: (X_tensor, y_tensor, e_tensor) -- note the order: features,
            event time, event flag.
        '''
        # gets data with index of item
        X_item = self.X[item] # (m)
        e_item = self.e[item] # (1)
        y_item = self.y[item] # (1)
        # constructs torch.Tensor object
        X_tensor = torch.from_numpy(X_item)
        e_tensor = torch.from_numpy(e_item)
        y_tensor = torch.from_numpy(y_item)
        return X_tensor, y_tensor, e_tensor

    def __len__(self):
        ''' Number of samples in the loaded split. '''
        return self.X.shape[0]

In [2]:
from torch.utils.data import DataLoader

# Path of the .h5 data file
h5_file = './data/support/support_train_test.h5'


# Build the dataset instances for the train and test splits
train_dataset = SurvivalDataset(h5_file, is_train=True)
test_dataset = SurvivalDataset(h5_file, is_train=False)
# Optional: inspect the length of the dataset
print("Training dataset length:", len(train_dataset))

# Samples can be accessed by index;
# here we fetch the first sample (returned as features, time, event flag)
X_sample, y_sample, e_sample = train_dataset[0]

# Print the shapes of the sample tensors (m is the feature dimension)
print("X_sample shape:", X_sample.shape)  # expected (m,)
print("y_sample shape:", y_sample.shape)  # expected (1,)
print("e_sample shape:", e_sample.shape)  # expected (1,)


=> load 7098 samples
=> load 1775 samples
Training dataset length: 7098
X_sample shape: torch.Size([14])
y_sample shape: torch.Size([1])
e_sample shape: torch.Size([1])


In [3]:
# Full-batch loading: each loader yields its whole split as a single batch
# (the training loop in this notebook relies on both loaders having exactly
# one batch).
train_loader = DataLoader(train_dataset, batch_size=len(train_dataset), shuffle=True)
# Evaluation data does not need to be shuffled.
test_loader = DataLoader(test_dataset, batch_size=len(test_dataset), shuffle=False)

# Sanity check: iterate over the training loader and print each batch's shapes
for batch_idx, (X_batch, y_batch, e_batch) in enumerate(train_loader):
    print(f"Batch {batch_idx}:")
    print("X_batch shape:", X_batch.shape)  # expected (batch_size, m)
    print("y_batch shape:", y_batch.shape)  # expected (batch_size, 1)
    print("e_batch shape:", e_batch.shape)  # expected (batch_size, 1)

Batch 0:
X_batch shape: torch.Size([7098, 14])
y_batch shape: torch.Size([7098, 1])
e_batch shape: torch.Size([7098, 1])


In [4]:
from lifelines.utils import concordance_index
def c_index(risk_pred, y, e):
    ''' Computes the concordance index (c-index) of a prediction.

    :param risk_pred: (np.ndarray or torch.Tensor) model prediction
    :param y: (np.ndarray or torch.Tensor) the times of event
    :param e: (np.ndarray or torch.Tensor) flag that records whether the event occurs
    :return c_index: the value returned by concordance_index(y, risk_pred, e)
    '''
    def _as_numpy(value):
        # ndarrays pass through untouched; anything else is handled as a
        # torch.Tensor: detached from the graph and copied to host memory
        if isinstance(value, np.ndarray):
            return value
        return value.detach().cpu().numpy()

    event_times = _as_numpy(y)
    predicted_scores = _as_numpy(risk_pred)
    event_observed = _as_numpy(e)
    return concordance_index(event_times, predicted_scores, event_observed)

In [5]:
class Regularization(object):
    ''' Callable that computes a weighted sum of norms over a model's weight parameters. '''

    def __init__(self, order, weight_decay):
        ''' Stores the regularization settings.

        :param order: (int) norm order number (passed to torch.norm as p)
        :param weight_decay: (float) weight decay rate, i.e. the factor the
            summed norms are multiplied by
        '''
        super(Regularization, self).__init__()
        self.order = order
        self.weight_decay = weight_decay

    def __call__(self, model):
        ''' Calculates the regularization(self.order) loss for model.

        :param model: (torch.nn.Module object)
        :return reg_loss: weight_decay times the sum of the norms of every
            parameter whose name contains 'weight'
        '''
        # only parameters with 'weight' in their name are regularized
        weight_norms = (
            torch.norm(param, p=self.order)
            for param_name, param in model.named_parameters()
            if 'weight' in param_name
        )
        return self.weight_decay * sum(weight_norms)


In [6]:
import torch.nn as nn
class DeepNN(nn.Module):
    ''' Fully connected network that maps a feature vector to a single output.

    Every hidden stage is Linear -> BatchNorm1d -> ReLU -> Dropout, followed by
    a final Linear layer with one output unit. With the default arguments the
    architecture (and therefore the state_dict keys) is
    14 -> 32 -> 64 -> 32 -> 1 with dropout 0.25.
    '''
    def __init__(self, in_features=14, hidden_dims=(32, 64, 32), dropout=0.25):
        ''' Builds the layer stack.

        :param in_features: (int) number of input features
        :param hidden_dims: (sequence of int) width of each hidden stage, in order
        :param dropout: (float) dropout probability used after every hidden stage
        '''
        super(DeepNN, self).__init__()
        layers = []
        prev_dim = in_features
        for hidden_dim in hidden_dims:
            layers.extend([
                nn.Linear(prev_dim, hidden_dim),
                nn.BatchNorm1d(hidden_dim),
                nn.ReLU(),
                nn.Dropout(p=dropout),
            ])
            prev_dim = hidden_dim
        # output head: one value per sample
        layers.append(nn.Linear(prev_dim, 1))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        ''' Applies the layer stack to x and returns its output. '''
        return self.model(x)

In [7]:
class NegativeLogLikelihood(nn.Module):
    ''' Loss over a batch of risk predictions, event times and event flags,
    plus an L2 penalty on the model's weights.

    For each sample j the "risk set" is every sample i in the batch whose time
    y[i] is not smaller than y[j]. The loss is the negative mean, over samples
    with e == 1, of
        risk_pred[j] - log(mean over the risk set of exp(risk_pred[i])).
    '''
    def __init__(self, l2_reg=0.001):
        ''' :param l2_reg: (float) weight of the L2 penalty on the model weights '''
        super(NegativeLogLikelihood, self).__init__()
        self.L2_reg = l2_reg
        self.reg = Regularization(order=2, weight_decay=self.L2_reg)

    def forward(self, risk_pred, y, e, model):
        ''' Computes the loss.

        :param risk_pred: (torch.Tensor) model output; the code below assumes shape (n, 1)
        :param y: (torch.Tensor) event times, assumed shape (n, 1)
        :param e: (torch.Tensor) event flags (1: event occurred), assumed shape (n, 1)
        :param model: (torch.nn.Module) model whose weights are regularized;
            it is moved to risk_pred's device as a side effect
        :return: (torch.Tensor) scalar loss
        '''
        # all tensors are brought onto the device of the predictions
        device = risk_pred.device
        y = y.to(device)
        e = e.to(device)
        model = model.to(device)

        # mask[i, j] == 1 when sample i belongs to the risk set of sample j,
        # i.e. unless y[j] > y[i]
        mask = torch.ones(y.shape[0], y.shape[0], device=device)
        mask[(y.T - y) > 0] = 0

        # log(sum_i exp(risk_pred[i]) * mask[i, j]) evaluated with logsumexp:
        # log(mask) is 0 inside the risk set and -inf outside, so excluded
        # entries contribute nothing, and large predictions no longer overflow
        # exp() (which turned the whole loss into NaN via inf * 0).
        log_risk_sum = torch.logsumexp(risk_pred + torch.log(mask), dim=0)
        # subtracting log(risk-set size) turns the sum into a mean
        log_loss = (log_risk_sum - torch.log(torch.sum(mask, dim=0))).reshape(-1, 1)

        # a batch without any observed event would otherwise divide 0 by 0
        num_events = torch.sum(e).clamp(min=1)
        neg_log_loss = -torch.sum((risk_pred - log_loss) * e) / num_events
        l2_loss = self.reg(model)

        return neg_log_loss + l2_loss



In [8]:
def adjust_learning_rate(optimizer, epoch, lr, lr_decay_rate):
    ''' Sets every parameter group's learning rate to lr / (1 + epoch * lr_decay_rate).

    :param optimizer: (torch.optim object) optimizer whose param_groups are updated in place
    :param epoch: (int) current epoch number
    :param lr: (float) the initial learning rate
    :param lr_decay_rate: (float) learning rate decay rate
    :return lr_: (float) the updated learning rate, read back from the first parameter group
    '''
    for group in optimizer.param_groups:
        group.update(lr=lr / (1 + epoch * lr_decay_rate))
    first_group = optimizer.param_groups[0]
    return first_group['lr']

In [9]:
def train(model, criterion, optimizer, train_loader, test_loader, device,
          lr=0.047, lr_decay_rate=4.169e-3, num_epochs=500, patience=50,
          checkpoint_path='test.pth'):
    ''' Trains model with early stopping on the c-index of test_loader.

    :param model: (torch.nn.Module) network to train
    :param criterion: loss callable invoked as criterion(risk_pred, y, e, model)
    :param optimizer: (torch.optim object) optimizer for model's parameters
    :param train_loader: loader yielding (X, y, e) training batches
    :param test_loader: loader yielding (X, y, e) batches used for validation
    :param device: (torch.device) device the batches are moved to
    :param lr: (float) initial learning rate handed to adjust_learning_rate
    :param lr_decay_rate: (float) decay rate handed to adjust_learning_rate
    :param num_epochs: (int) maximum number of epochs
    :param patience: (int) number of validation batches without a new best
        c-index after which training stops early
    :param checkpoint_path: (String) file the best model state is saved to
    :return best_c_index: the best validation c-index observed

    NOTE: the per-epoch log line uses the loss / c-index of the last batch of
    each loader, so both loaders are expected to yield exactly one batch (and
    at least one, otherwise those values are undefined).
    NOTE(review): the best model is selected on test_loader, so the returned
    c-index is chosen on the same data it is measured on.
    '''
    # training
    best_c_index = 0  # best validation c-index seen so far
    flag = 0  # validation batches since best_c_index last improved (early stopping counter)
    for epoch in range(1, num_epochs + 1):
        # adjusts learning rate
        current_lr = adjust_learning_rate(optimizer, epoch, lr, lr_decay_rate)
        # train step
        model.train()  # switch the model to training mode
        for X, y, e in train_loader:
            # move the batch to the target device
            X, y, e = X.to(device), y.to(device), e.to(device)
            # makes predictions
            risk_pred = model(X)
            train_loss = criterion(risk_pred, y, e, model)
            # the prediction is negated before being scored
            train_c = c_index(-risk_pred, y, e)
            # updates parameters
            optimizer.zero_grad()
            train_loss.backward()
            optimizer.step()

        # valid step
        model.eval()  # switch the model to evaluation mode
        for X, y, e in test_loader:
            # move the batch to the target device
            X, y, e = X.to(device), y.to(device), e.to(device)
            # makes predictions; no gradients are needed for evaluation
            with torch.no_grad():
                risk_pred = model(X)
                valid_loss = criterion(risk_pred, y, e, model)
                valid_c = c_index(-risk_pred, y, e)
                if best_c_index < valid_c:
                    # new best: reset the counter and save the model
                    best_c_index = valid_c
                    flag = 0
                    torch.save({
                        'model': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'epoch': epoch}, checkpoint_path)
                else:
                    flag += 1
                    if flag >= patience:
                        print(f"Early stopping at epoch {epoch}, best c-index: {best_c_index}")
                        return best_c_index
        # notes that, train loader and valid loader both have one batch!!!
        print('\rEpoch: {}\tLoss: {:.8f}({:.8f})\tc-index: {:.8f}({:.8f})\tlr: {:g}'.format(
            epoch, train_loss.item(), valid_loss.item(), train_c, valid_c, current_lr), end='', flush=False)
        print(f"\nFlag: {flag}, Best c-index: {best_c_index}")
    return best_c_index


开始训练 (Start training)

In [10]:
# Instantiate the network and move it to the GPU when CUDA is available,
# otherwise keep it on the CPU. The last expression displays the model.
model = DeepNN()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

DeepNN(
  (model): Sequential(
    (0): Linear(in_features=14, out_features=32, bias=True)
    (1): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (2): ReLU()
    (3): Dropout(p=0.25, inplace=False)
    (4): Linear(in_features=32, out_features=64, bias=True)
    (5): BatchNorm1d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (6): ReLU()
    (7): Dropout(p=0.25, inplace=False)
    (8): Linear(in_features=64, out_features=32, bias=True)
    (9): BatchNorm1d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    (10): ReLU()
    (11): Dropout(p=0.25, inplace=False)
    (12): Linear(in_features=32, out_features=1, bias=True)
  )
)

In [11]:
# Loss module (negative log likelihood plus L2 penalty), placed on the same device as the model
criterion = NegativeLogLikelihood().to(device)

In [12]:
import torch.optim as optim
# Adam over all model parameters; train() overwrites the learning rate at the
# start of every epoch through adjust_learning_rate.
optimizer= optim.Adam(model.parameters(), lr= 0.047)

In [13]:
# Run training; returns the best validation c-index (the best model is saved by train())
best_c_index = train(model,criterion,optimizer,train_loader,test_loader,device)

Epoch: 1	Loss: 0.14533335(0.03169251)	c-index: 0.51327158(0.52240192)	lr: 0.0468049
Flag: 0, Best c-index: 0.5224019231274774
Epoch: 2	Loss: 0.13854755(0.03476819)	c-index: 0.49940124(0.50117421)	lr: 0.0466114
Flag: 1, Best c-index: 0.5224019231274774
Epoch: 3	Loss: 0.04813283(0.03635558)	c-index: 0.51582288(0.50008763)	lr: 0.0464194
Flag: 2, Best c-index: 0.5224019231274774
Epoch: 4	Loss: 0.03951707(0.03416699)	c-index: 0.52283723(0.51328722)	lr: 0.0462291
Flag: 3, Best c-index: 0.5224019231274774
Epoch: 5	Loss: 0.03193799(0.03379763)	c-index: 0.53582636(0.51964970)	lr: 0.0460403
Flag: 4, Best c-index: 0.5224019231274774
Epoch: 6	Loss: 0.01644735(0.03631753)	c-index: 0.55441530(0.51872596)	lr: 0.045853
Flag: 5, Best c-index: 0.5224019231274774
Epoch: 7	Loss: 0.00224332(0.03720739)	c-index: 0.57249007(0.51803516)	lr: 0.0456673
Flag: 6, Best c-index: 0.5224019231274774
Epoch: 8	Loss: -0.00243901(0.03470489)	c-index: 0.57909415(0.54667834)	lr: 0.045483
Flag: 0, Best c-index: 0.5466783405

Epoch: 65	Loss: -0.05480414(-0.04240985)	c-index: 0.61951628(0.60919459)	lr: 0.0369792
Flag: 8, Best c-index: 0.611157439997196
Epoch: 66	Loss: -0.05618580(-0.04305890)	c-index: 0.61846217(0.60861260)	lr: 0.0368583
Flag: 9, Best c-index: 0.611157439997196
Epoch: 67	Loss: -0.05916560(-0.04392220)	c-index: 0.62111702(0.60815766)	lr: 0.0367382
Flag: 10, Best c-index: 0.611157439997196
Epoch: 68	Loss: -0.06025360(-0.04508323)	c-index: 0.62220356(0.60846363)	lr: 0.0366188
Flag: 11, Best c-index: 0.611157439997196
Epoch: 69	Loss: -0.05945902(-0.04586243)	c-index: 0.62046407(0.60862647)	lr: 0.0365003
Flag: 12, Best c-index: 0.611157439997196
Epoch: 70	Loss: -0.05699246(-0.04549984)	c-index: 0.62144637(0.60849065)	lr: 0.0363825
Flag: 13, Best c-index: 0.611157439997196
Epoch: 71	Loss: -0.05735227(-0.04415003)	c-index: 0.62132970(0.60829714)	lr: 0.0362655
Flag: 14, Best c-index: 0.611157439997196
Epoch: 72	Loss: -0.05977945(-0.04357304)	c-index: 0.62110461(0.60861844)	lr: 0.0361492
Flag: 15, Be

Epoch: 129	Loss: -0.08970614(-0.04939338)	c-index: 0.63544333(0.61357449)	lr: 0.0305631
Flag: 17, Best c-index: 0.619614206140338
Epoch: 130	Loss: -0.08364449(-0.04795608)	c-index: 0.63469541(0.61298703)	lr: 0.0304805
Flag: 18, Best c-index: 0.619614206140338
Epoch: 131	Loss: -0.09188589(-0.04617238)	c-index: 0.63771653(0.61185846)	lr: 0.0303983
Flag: 19, Best c-index: 0.619614206140338
Epoch: 132	Loss: -0.09070016(-0.04607172)	c-index: 0.63698486(0.61295745)	lr: 0.0303166
Flag: 20, Best c-index: 0.619614206140338
Epoch: 133	Loss: -0.08941592(-0.04720452)	c-index: 0.63475375(0.61515763)	lr: 0.0302352
Flag: 21, Best c-index: 0.619614206140338
Epoch: 134	Loss: -0.08813816(-0.04775147)	c-index: 0.63653189(0.61465596)	lr: 0.0301544
Flag: 22, Best c-index: 0.619614206140338
Epoch: 135	Loss: -0.08966997(-0.04463486)	c-index: 0.63643719(0.61198844)	lr: 0.0300739
Flag: 23, Best c-index: 0.619614206140338
Epoch: 136	Loss: -0.09004361(-0.03903951)	c-index: 0.63316544(0.60876667)	lr: 0.0299939
Fl

In [14]:
# Display the best validation c-index returned by train().
# NOTE(review): train() never lowers best_c_index, yet the value shown below
# this cell is smaller than the best value in the training log above --
# the output presumably comes from a different run; re-run all cells to confirm.
best_c_index

0.616195280392804