In [24]:
import torch
import pandas as pd
import numpy as np
from sklearn.model_selection import train_test_split
import random
import datetime
from torch.utils.data import DataLoader
from lib.dataset import *
from lib.model import *
from sklearn.metrics import roc_auc_score

In [25]:
# Seed every RNG this notebook imports so that Restart & Run All is repeatable.
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)  # seeds the CPU and all CUDA generators
device = torch.device('cuda')

In [26]:
# Development cohort used for the train/valid/test split.
MERGED_DATA_PATH = './data/processed/sepsis_merged_8_2_0.5_mimic3cv_mimiciv.pt'
data = SepsisDataset(MERGED_DATA_PATH)
# Presumably per-feature minima/maxima (see lib.dataset), used as min-max scaling
# statistics for every cohort in this notebook.
# NOTE(review): computed over the whole dataset *before* the split, so the
# validation/test rows contribute to the normalization statistics.
data_min, data_max = get_data_min_max(data)

In [27]:
# Build [scaled values, mask, label] triples from elements 2, 3 and -2 of each
# dataset sample, min-max scaling the values with data_min/data_max.
feature_range = data_max - data_min
data_new = []
for idx in range(len(data)):
    sample = data[idx]
    scaled_values = (sample[2] - data_min) / feature_range
    data_new.append([scaled_values, sample[3], sample[-2]])

In [28]:
# External XJTU cohort, min-max scaled with the development-set statistics and
# packed into dense tensors for one-shot evaluation.
# NOTE: despite its name, `data_train` holds this external evaluation set.
data_exter = SepsisDataset('./data/processed/sepsis_merged_8_2_0.5_xjtu.pt')
first_values = data_exter[0][2]
tensor_shape = (len(data_exter), first_values.shape[0], first_values.shape[1])
data_train = torch.zeros(*tensor_shape, device='cuda')
data_train_mask = torch.zeros(*tensor_shape, device='cuda')
data_label = torch.zeros(len(data_exter), 1)
print(data_train.shape, data_label.shape)
for i in range(len(data_exter)):
    record = data_exter[i]
    data_train[i, :, :] = (record[2] - data_min) / (data_max - data_min)
    data_train_mask[i, :, :] = record[3]
    data_label[i, :] = record[-2]

torch.Size([4917, 12, 215]) torch.Size([4917, 1])


In [74]:
# 70/10/20 split: hold out 20% for test, then 12.5% of the remaining 80%
# (10% overall) for validation.
BATCH_SIZE = 128
trainset, testset = train_test_split(data_new, test_size=0.2, random_state=42)
trainset, validset = train_test_split(trainset, test_size=0.125, random_state=42)
trainloader = DataLoader(trainset, shuffle=True, batch_size=BATCH_SIZE)
# AUC does not depend on sample order, so the evaluation loaders are not
# shuffled; this also keeps them from drawing on the global torch RNG.
validloader = DataLoader(validset, shuffle=False, batch_size=BATCH_SIZE)
testloader = DataLoader(testset, shuffle=False, batch_size=BATCH_SIZE)

In [75]:
class SelfAttention(nn.Module):
    """Single-head scaled dot-product self-attention with a linear read-out.

    ``forward`` receives the same ``(values, mask)``-style tuple as the LSTM
    model and attends over its first element only; the rest is ignored.

    Args:
        input_size: number of features per time step.
        hidden_size: width of the query/key/value projections.
    """

    def __init__(self, input_size=215, hidden_size=128):
        super(SelfAttention, self).__init__()
        self.query = nn.Linear(input_size, hidden_size)
        self.key = nn.Linear(input_size, hidden_size)
        self.value = nn.Linear(input_size, hidden_size)
        self.softmax = nn.Softmax(dim=-1)
        self.linear = nn.Linear(hidden_size, 1)

    def forward(self, x):
        """Return one score per time step.

        Args:
            x: indexable whose first element is the value tensor, presumably
                ``(batch, seq_len, input_size)`` — TODO confirm.

        Returns:
            Tensor with the value tensor's leading dimensions and a trailing
            dimension of size 1.
        """
        x = x[0]
        q = self.query(x)
        k = self.key(x)
        v = self.value(x)

        # Attention scores, scaled by sqrt(d_k).
        scores = torch.matmul(q, k.transpose(-2, -1)) / (k.size(-1) ** 0.5)
        # Turn the scores into attention probabilities over the key axis.
        attn_probs = self.softmax(scores)
        # Attention-weighted sum of the values, projected to a single output.
        attended_values = torch.matmul(attn_probs, v)
        return self.linear(attended_values)

class SepsisLSTM(nn.Module):
    """Two-stream LSTM classifier over a ``(values, mask)`` pair.

    One LSTM encodes the value sequence and a second, identically configured
    LSTM encodes the mask sequence. The top-layer final hidden states of the
    two streams are multiplied element-wise and fed to an MLP head that ends
    in a sigmoid, so outputs lie in [0, 1].

    Args:
        in_dim: number of features per time step (both streams).
        hidden_dim: LSTM hidden size.
        n_layer: number of stacked LSTM layers.
        n_classes: number of sigmoid outputs.
    """

    def __init__(self, in_dim, hidden_dim, n_layer, n_classes):
        super().__init__()
        self.in_dim, self.hidden_dim = in_dim, hidden_dim
        self.n_layer, self.n_classes = n_layer, n_classes

        def make_encoder():
            return nn.LSTM(input_size=in_dim, hidden_size=hidden_dim,
                           num_layers=n_layer, batch_first=True)

        # Registration order (val, mask, head) fixes both the seeded weight
        # initialisation and the state_dict keys that checkpoints rely on.
        self.lstm_val = make_encoder()
        self.lstm_mask = make_encoder()
        head_layers = [
            nn.Linear(hidden_dim, 64),
            nn.ReLU(),
            nn.Linear(64, n_classes),
            nn.Sigmoid(),
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Score a batch given ``x = (values, mask)``; returns ``(batch, n_classes)``."""
        values, mask = x[0], x[1]
        _, (val_hidden, _) = self.lstm_val(values)
        _, (mask_hidden, _) = self.lstm_mask(mask)
        # h_n is (num_layers, batch, hidden); keep only the top layer of each stream.
        gated = val_hidden[-1] * mask_hidden[-1]
        return self.classifier(gated)
    
# class SepsisAttention(nn.Module):

#     def __init__():
#         super(SepsisAttention, self).__init__()


In [76]:
def get_linear_schedule_with_warmup(optimizer, num_warmup_steps, num_training_steps, last_epoch=-1):
    """
    Build a LambdaLR whose multiplier ramps linearly from 0 to 1 over the warmup
    steps, then decays linearly from 1 to 0 at ``num_training_steps`` and stays at 0.

    Args:
        optimizer (:class:`~torch.optim.Optimizer`):
            Optimizer whose initial learning rate(s) are scaled.
        num_warmup_steps (:obj:`int`):
            Length of the linear warmup phase.
        num_training_steps (:obj:`int`):
            Step at which the multiplier reaches 0.
        last_epoch (:obj:`int`, `optional`, defaults to -1):
            Index of the last step when resuming training.
    Return:
        :obj:`torch.optim.lr_scheduler.LambdaLR` implementing the schedule.
    """
    # Both spans are clamped to >= 1 so the divisions below are always defined.
    warmup_span = max(1, num_warmup_steps)
    decay_span = max(1, num_training_steps - num_warmup_steps)

    def _multiplier(step: int):
        if step >= num_warmup_steps:
            remaining = float(num_training_steps - step) / float(decay_span)
            return remaining if remaining > 0.0 else 0.0
        return float(step) / float(warmup_span)

    return torch.optim.lr_scheduler.LambdaLR(optimizer, _multiplier, last_epoch=last_epoch)

In [77]:
NUM_EPOCHS = 100  # keep in sync with the epoch count of the training loop

model = SepsisLSTM(in_dim=215, hidden_dim=1024, n_layer=2, n_classes=1)

model.to(device)
# The classifier head already ends in Sigmoid, hence plain BCE on probabilities.
loss = nn.BCELoss()
optim = torch.optim.Adamax(lr=1e-4, params=model.parameters())
# The scheduler is stepped once per training batch, so decay to 0 over
# NUM_EPOCHS * batches-per-epoch steps (previously hard-coded as 11700).
scheduler = get_linear_schedule_with_warmup(optim, 0, NUM_EPOCHS * len(trainloader))
model

SepsisLSTM(
  (lstm_val): LSTM(215, 1024, num_layers=2, batch_first=True)
  (lstm_mask): LSTM(215, 1024, num_layers=2, batch_first=True)
  (classifier): Sequential(
    (0): Linear(in_features=1024, out_features=64, bias=True)
    (1): ReLU()
    (2): Linear(in_features=64, out_features=1, bias=True)
    (3): Sigmoid()
  )
)

In [78]:
# Train for 100 epochs. After each epoch: score the validation split (saving a
# checkpoint whenever validation AUC improves) and, for monitoring only, print
# the AUC on the external cohort held in data_train / data_train_mask / data_label.
auc_best = 0
for i in range(100):

    for vals,masks,label in trainloader:
        # Re-enter training mode; the validation pass below leaves the model in eval mode.
        model.train()
        optim.zero_grad()
        y_pred = model((vals,masks))
        error = loss(y_pred, label.reshape(-1,1))
        # print(error)
        error.backward()
        # Clip the global gradient norm to 0.1 before the update.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1)
        optim.step()
        # The LR schedule advances once per batch, not once per epoch.
        scheduler.step()

    with torch.no_grad():
        y_pred_all = []
        y_label_all = []
        model.eval()
        for vals,masks,label in validloader:
            y_pred_all.append(model((vals,masks)))
            y_label_all.append(label.reshape(-1,1))
        y_pred_all = torch.concat((y_pred_all))
        y_label_all = torch.concat((y_label_all))
        auc = roc_auc_score(y_label_all.cpu(), y_pred_all.cpu())
        if auc > auc_best:
            auc_best = auc
            # NOTE(review): `model.state_dict` (no call) saves the bound method, which
            # pickles the whole model object rather than a plain state dict; the
            # loading cell compensates by calling the loaded object. The conventional
            # `torch.save(model.state_dict(), ...)` must be changed together with that
            # loader. Newer PyTorch versions whose `torch.load` defaults to
            # weights_only=True may refuse to unpickle this file — confirm.
            torch.save(model.state_dict, '1.pt')
        print(auc, end=' ')
    
    # External-cohort AUC, printed each epoch for monitoring only (the model is
    # still in eval mode from the validation pass). It does not drive checkpointing;
    # choosing epochs by eye from it would leak the external set into model selection.
    with torch.no_grad():
        a = model((data_train,data_train_mask)).cpu()
        b = data_label.cpu()
        print(roc_auc_score(b, a))


0.7168186715606971 0.6771305253910718
0.7302860023544442 0.6107000514047599
0.7852997162146451 0.7320803271368476
0.8484454879352392 0.7794511540583777
0.862316091415693 0.7540336020967052
0.8703078805154922 0.7797872875664364
0.8750534833487331 0.7726358030678017
0.8587166443946861 0.6586411964300669
0.8771869335413202 0.7723936744009923
0.8787776660125498 0.7632238993406752
0.8660478052881324 0.7884617166939026
0.876840497944686 0.7963460417507277
0.8839040655817648 0.7907909845385206
0.8793483904159722 0.7395024065570364
0.8813867590997004 0.7790029760476327
0.8830485673523047 0.8038818206884133
0.8861361275372523 0.7959149931432736
0.8890974223026387 0.7985387708734606
0.8854629080914834 0.8059476101751635
0.8878259424719215 0.7945351245448296
0.889278759685252 0.8034179200369664
0.8899024378993012 0.7939420499874055
0.8913218353743694 0.8028406508709846
0.888864778561285 0.7968779635108955
0.8903931432111096 0.811156852042669
0.8928751471527561 0.8072429902675199
0.892151445074189

In [33]:
print(auc_best)
# '1.pt' may contain either a plain state dict or a pickled bound
# `model.state_dict` method (the format written by `torch.save(model.state_dict, ...)`),
# which must be called to obtain the weights. Accept both.
checkpoint = torch.load('1.pt')
model_dict = checkpoint() if callable(checkpoint) else checkpoint
# NOTE: `model_best` aliases `model` (not a copy); loading overwrites `model` too.
model_best = model
model_best.load_state_dict(model_dict)

0.9075913641513529


<All keys matched successfully>

In [34]:
# Training-split AUC of the restored best-validation model. Diagnostic only:
# it must not feed back into `auc_best` / `model_best`, which are defined by
# validation AUC (the previous version overwrote `auc_best` with this value).
with torch.no_grad():
    model_best.eval()
    train_preds = []
    train_labels = []
    for vals, masks, label in trainloader:
        train_preds.append(model_best((vals, masks)))
        train_labels.append(label.reshape(-1, 1))
    auc_train = roc_auc_score(torch.cat(train_labels).cpu(), torch.cat(train_preds).cpu())
    print(auc_train)

0.9197170327119746


In [35]:
# Held-out test AUC of the best-validation model. The test split is only
# reported, never used to (re)select the model — doing so would leak test data
# into model selection, as the previous `if auc > auc_best` branch did.
with torch.no_grad():
    model_best.eval()
    test_preds = []
    test_labels = []
    for vals, masks, label in testloader:
        test_preds.append(model_best((vals, masks)))
        test_labels.append(label.reshape(-1, 1))
    auc_test = roc_auc_score(torch.cat(test_labels).cpu(), torch.cat(test_preds).cpu())
    print(auc_test)

0.9088249667330663


In [36]:
# External validation on the eICU cohort, min-max scaled with the
# development-set statistics (data_min / data_max).
data_exter = SepsisDataset('./data/processed/sepsis_merged_8_2_0.5_eicu.pt')
first_values = data_exter[0][2]
# Allocate directly on the GPU instead of building on the CPU and copying.
data_train = torch.zeros(len(data_exter), first_values.shape[0], first_values.shape[1], device='cuda')
data_train_mask = torch.zeros_like(data_train)
data_label = torch.zeros(len(data_exter), 1)
print(data_train.shape, data_label.shape)
for i in range(len(data_exter)):
    record = data_exter[i]  # fetch each sample once
    data_train[i] = (record[2] - data_min) / (data_max - data_min)
    data_train_mask[i] = record[3]
    data_label[i] = record[-2]

# Set eval mode explicitly rather than relying on an earlier cell having done it.
model_best.eval()
with torch.no_grad():
    a = model_best((data_train, data_train_mask)).cpu()
    b = data_label.cpu()
print(roc_auc_score(b, a))

torch.Size([23454, 12, 215]) torch.Size([23454, 1])
0.8922810475125396


In [37]:
# External validation on the XJTU cohort, min-max scaled with the
# development-set statistics (data_min / data_max).
data_exter = SepsisDataset('./data/processed/sepsis_merged_8_2_0.5_xjtu.pt')
first_values = data_exter[0][2]
# Allocate directly on the GPU instead of building on the CPU and copying.
data_train = torch.zeros(len(data_exter), first_values.shape[0], first_values.shape[1], device='cuda')
data_train_mask = torch.zeros_like(data_train)
data_label = torch.zeros(len(data_exter), 1)
print(data_train.shape, data_label.shape)
for i in range(len(data_exter)):
    record = data_exter[i]  # fetch each sample once
    data_train[i] = (record[2] - data_min) / (data_max - data_min)
    data_train_mask[i] = record[3]
    data_label[i] = record[-2]

# Set eval mode explicitly rather than relying on an earlier cell having done it.
model_best.eval()
with torch.no_grad():
    a = model_best((data_train, data_train_mask)).cpu()
    b = data_label.cpu()
print(roc_auc_score(b, a))


torch.Size([4917, 12, 215]) torch.Size([4917, 1])
0.7952823637318731
