In [1]:
import numpy as np
import random
import pandas as pd
from tqdm import tqdm 

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, TensorDataset
from torchsummary import summary
import torchinfo

from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score, accuracy_score
from sklearn.model_selection import train_test_split

import matplotlib.pyplot as plt

In [2]:
def tnr_score(y_test, y_pred):
    """Return the true-negative rate (specificity) of binary predictions.

    Both inputs are flattened before use, so a column vector or a list of
    one-element arrays is treated the same as a flat list. Without the
    flattening, an (N,) label vector multiplied with an (N, 1) prediction
    array broadcasts to (N, N) and the counts below become meaningless.

    Args:
        y_test: ground-truth labels, encoded as 0 (negative) / 1 (positive).
        y_pred: predicted labels, encoded the same way.

    Returns:
        TN / (TN + FP), or 0 when ``y_test`` contains no negative sample.

    Raises:
        ValueError: if the two inputs do not hold the same number of labels.
    """
    y_t = np.asarray(y_test).ravel()
    y_p = np.asarray(y_pred).ravel()
    if y_t.shape != y_p.shape:
        raise ValueError(
            f"y_test and y_pred must have the same number of elements, got {y_t.size} and {y_p.size}"
        )

    # With 0/1 encoding, (1 - y) flags the negatives.
    tn = np.sum((1 - y_t) * (1 - y_p))
    fp = np.sum(y_p * (1 - y_t))
    if (tn + fp) == 0:
        return 0
    return tn / (tn + fp)

In [53]:
def split_train_test_val(data, target, test_size, val_size):
    """Randomly split samples into train / test / validation sets with class-balanced hold-outs.

    The sample indices are shuffled with the global ``random`` generator and
    visited once. For each class (0 and 1) the first ``int(test_size * n) // 2``
    samples met go to the test set, the next ``int(val_size * n) // 2`` go to
    the validation set, and everything else (including any label that is
    neither 0 nor 1) goes to the training set.

    Args:
        data: indexable collection of samples.
        target: indexable collection of labels, aligned with ``data``.
        test_size: fraction of all samples reserved for the test set.
        val_size: fraction of all samples reserved for the validation set.

    Returns:
        ``(x_train, x_test, x_val, y_train, y_test, y_val)`` as Python lists.
    """
    total = len(target)
    per_class_test = int(test_size * total) // 2
    per_class_val = int(val_size * total) // 2

    order = list(range(total))
    random.shuffle(order)

    x_train, y_train = [], []
    x_test, y_test = [], []
    x_val, y_val = [], []

    # Number of samples of each class already placed in each hold-out set.
    taken_test = {0: 0, 1: 0}
    taken_val = {0: 0, 1: 0}

    for position in order:
        sample = data[position]
        label = target[position]

        if label == 0:
            cls = 0
        elif label == 1:
            cls = 1
        else:
            cls = None

        if cls is not None and taken_test[cls] < per_class_test:
            x_test.append(sample)
            y_test.append(cls)
            taken_test[cls] += 1
            continue

        if cls is not None and taken_val[cls] < per_class_val:
            x_val.append(sample)
            y_val.append(cls)
            taken_val[cls] += 1
            continue

        x_train.append(sample)
        y_train.append(label)

    return x_train, x_test, x_val, y_train, y_test, y_val


def prepareData(dataGroup, id_list, window_time):
    """Build the per-patient feature matrices for the first ``window_time`` days.

    For every encounter the first ``window_time * 24`` rows of the first nine
    columns of ``dynamic.parquet`` are kept, and each value of the first row of
    ``static.parquet`` is appended as a constant column.

    Only the files actually needed are read: ``patients.parquet`` and the
    per-patient ``mask.parquet`` are not used to build the features, so they are
    not loaded here.

    Args:
        dataGroup: ``"dataECMO"`` reads from ``../dataECMO/``; any other value
            reads from ``../dataRea/``.
        id_list: encounter ids (strings) naming the sub-directories of
            ``finalData/`` to load.
        window_time: length of the observation window, in days.

    Returns:
        ``np.array`` of the stacked per-patient matrices, one entry per id in
        ``id_list`` order; each matrix has ``window_time * 24`` rows and
        9 + (number of static values) columns.
    """
    dataPath = "../dataECMO/" if dataGroup == "dataECMO" else "../dataRea/"
    finalDataPath = dataPath + "finalData/"

    # Loop invariants: window length in rows and the dynamic columns kept
    # (the same nine columns for both data groups).
    num_rows = window_time * 24
    idx_variables_kept = list(range(9))

    data = []

    for encounterId in tqdm(id_list, total=len(id_list)):
        patient_dir = finalDataPath + encounterId
        df_dynamic = pd.read_parquet(patient_dir + "/dynamic.parquet")
        df_static = pd.read_parquet(patient_dir + "/static.parquet")

        dynamic_part = df_dynamic.iloc[:num_rows, idx_variables_kept].to_numpy()

        # Repeat the static values of the first row on every time step:
        # (num_rows, 1) * (num_static,) broadcasts to (num_rows, num_static).
        static_values = np.asarray(df_static.to_numpy()[0], dtype=np.float64)
        static_part = np.ones(shape=(num_rows, 1)) * static_values

        data.append(np.hstack((dynamic_part, static_part)))

    return np.array(data)


def prepareDeathList(dataGroup, window_time, is_test_ECMO=False):
    """Select eligible encounters and build their binary mortality labels.

    An encounter is kept when the time between ``installation_date`` and
    ``withdrawal_date`` (plus 4 hours) covers at least ``window_time`` days and
    strictly less than 40 % of the entries of its ``mask.parquet`` are zero
    (i.e. not counted by the sum of the mask). With ``is_test_ECMO`` set, only
    encounters whose installation year is before 2020 are kept.

    Args:
        dataGroup: ``"dataECMO"`` reads from ``../dataECMO/``; any other value
            reads from ``../dataRea/``.
        window_time: minimum length of stay, in days.
        is_test_ECMO: restrict the selection to installations before 2020.

    Returns:
        ``(target, id_list)``: ``target[i]`` is 1 when ``delai_sortie_deces`` of
        encounter ``id_list[i]`` is <= 1 and 0 otherwise; ids are strings.
    """
    dataPath = "../dataECMO/" if dataGroup == "dataECMO" else "../dataRea/"

    patients_df = pd.read_parquet(dataPath + "patients.parquet")
    df_death = pd.read_csv(dataPath + "delais_deces.csv")

    required_hours = window_time * 24
    target, id_list = [], []

    for _, patient in tqdm(patients_df.iterrows(), total=len(patients_df)):
        encounterId = str(patient["encounterId"])

        # Share of mask entries that are not set, in percent.
        mask_values = pd.read_parquet(dataPath + "finalData/" + encounterId + "/mask.parquet").values
        n_cells = mask_values.size
        missing_pct = (n_cells - mask_values.sum()) / n_cells * 100

        withdrawn_at = pd.Timestamp(patient["withdrawal_date"])
        installed_at = pd.Timestamp(patient["installation_date"])
        stay_hours = (withdrawn_at - installed_at).total_seconds() / 3600 + 4

        if not (stay_hours >= required_hours and missing_pct < 40):
            continue
        if is_test_ECMO and not (installed_at.year < 2020):
            continue

        id_list.append(encounterId)

        death_delay = df_death.loc[df_death["encounterId"] == int(encounterId), "delai_sortie_deces"].to_numpy()[0]
        target.append(1 if death_delay <= 1 else 0)

    return target, id_list

In [4]:
# Cohort selector. prepareDeathList / prepareData only test for "dataECMO":
# any other value (such as "dataRangueil" below) makes them read "../dataRea/".
# dataGroup = "dataECMO"
dataGroup = "dataRangueil"

# Observation window in days; both helpers multiply it by 24 to get hours / rows.
window_time_days = 5
# Labels and ids of the eligible encounters, then their feature matrices (same order as id_list).
target, id_list = prepareDeathList(dataGroup, window_time_days)
data = prepareData(dataGroup, id_list, window_time_days)

100%|██████████| 2150/2150 [01:50<00:00, 19.38it/s]
100%|██████████| 1794/1794 [03:33<00:00,  8.42it/s]


In [42]:
class CNN_1D_0(nn.Module):
    """Small 1D CNN over the dynamic channels, with static features joined before the output layer.

    Input is channels-first: ``(batch, num_features, num_timesteps)``. The last
    ``num_static_features`` channels are treated as static (constant over time)
    and only their value at the first time step is used.

    Args:
        num_features: total number of input channels (dynamic + static).
        num_static_features: number of trailing static channels (may be 0).
        num_timesteps: length of the time axis. The default of 120 gives the
            480-wide flattened vector (16 channels * 30 steps) of a 5-day
            hourly window.

    ``forward`` returns raw logits of shape ``(batch, 1)`` (no sigmoid applied).
    """

    def __init__(self, num_features, num_static_features, num_timesteps=120):
        super().__init__()

        self.num_features = num_features
        self.num_static_features = num_static_features
        self.num_timesteps = num_timesteps

        self.conv1 = nn.Conv1d(in_channels=num_features-num_static_features, out_channels=8, kernel_size=1)
        self.pool = nn.MaxPool1d(kernel_size=2)
        self.conv2 = nn.Conv1d(in_channels=8, out_channels=16, kernel_size=1)

        # kernel_size=1 convolutions keep the length; each pooling halves it (floor).
        pooled_steps = (num_timesteps // 2) // 2
        self.fc1 = nn.Linear(16 * pooled_steps, 4)
        self.fc2 = nn.Linear(4+num_static_features, 1)

        # Kept for experimentation; currently not applied in forward().
        self.dropout1 = nn.Dropout(p=0.5)

    def forward(self, x):
        # Explicit split index: a "-num_static_features" slice would select
        # nothing / everything when num_static_features is 0.
        split = x.size(1) - self.num_static_features
        cnn_input = x[:, :split, :]
        static_input = x[:, split:, 0]

        out = self.pool(nn.functional.relu(self.conv1(cnn_input)))
        out = self.pool(nn.functional.relu(self.conv2(out)))
        out = torch.flatten(out, 1)
        out = nn.functional.relu(self.fc1(out))
        out = torch.cat((out, static_input), dim=1)
        out = self.fc2(out)
        return out
    
class CNN_1D_1(nn.Module):
    """1D CNN that convolves over all input channels (no separate static branch).

    Input is channels-first: ``(batch, num_features, num_timesteps)``.

    Args:
        num_features: number of input channels.
        num_timesteps: length of the time axis. The default of 120 gives the
            960-wide flattened vector (32 channels * 30 steps) of a 5-day
            hourly window.

    ``forward`` returns raw logits of shape ``(batch, 1)`` (no sigmoid applied).
    """

    def __init__(self, num_features, num_timesteps=120):
        super().__init__()

        self.num_features = num_features
        self.num_timesteps = num_timesteps

        self.conv1 = nn.Conv1d(in_channels=num_features, out_channels=16, kernel_size=1)
        self.pool = nn.MaxPool1d(kernel_size=2)
        self.conv2 = nn.Conv1d(in_channels=16, out_channels=32, kernel_size=1)

        # kernel_size=1 convolutions keep the length; each pooling halves it (floor).
        pooled_steps = (num_timesteps // 2) // 2
        self.fc1 = nn.Linear(32 * pooled_steps, 32)
        self.fc2 = nn.Linear(32, 1)

    def forward(self, x):
        out = self.pool(nn.functional.relu(self.conv1(x)))
        out = self.pool(nn.functional.relu(self.conv2(out)))
        out = torch.flatten(out, 1)
        out = nn.functional.relu(self.fc1(out))
        out = self.fc2(out)
        return out
    
class CNN2(nn.Module):
    """2D CNN that treats each sample as a one-channel (time x feature) image.

    Input is ``(batch, num_timesteps, num_features)``; a channel dimension is
    added internally.

    Args:
        num_timesteps: number of rows (time steps) of the input. Default 120.
        num_features: number of columns (features) of the input. Default 13.
            The defaults give the 3224-wide flattened vector (8 * 31 * 13).

    ``forward`` returns raw logits of shape ``(batch, 1)`` (no sigmoid applied).
    """

    def __init__(self, num_timesteps=120, num_features=13):
        super().__init__()

        self.conv1 = nn.Conv2d(in_channels=1, out_channels=8, kernel_size=(1, 3), padding=1)
        self.pool = nn.MaxPool2d(kernel_size=(2, 1))
        self.conv2 = nn.Conv2d(in_channels=8, out_channels=8, kernel_size=(1, 3), padding=1)

        # padding=1 pads both axes: with a (1, 3) kernel each convolution adds 2
        # rows and keeps the width; each (2, 1) pooling halves the rows (floor).
        pooled_rows = ((num_timesteps + 2) // 2 + 2) // 2
        self.fc1 = nn.Linear(8 * pooled_rows * num_features, 4)
        self.fc2 = nn.Linear(4, 1)

        # dropout1 is kept for experimentation; only dropout2 is applied in forward().
        self.dropout1 = nn.Dropout(p=0.5)
        self.dropout2 = nn.Dropout(p=0.75)

    def forward(self, x):
        x = x.unsqueeze(1)  # Add a channel dimension
        x = self.pool(nn.functional.relu(self.conv1(x)))
        x = self.pool(nn.functional.relu(self.conv2(x)))
        x = torch.flatten(x, 1)
        x = nn.functional.relu(self.fc1(x))
        x = self.dropout2(x)
        x = self.fc2(x)
        return x

class LSTMModel(nn.Module):
    """Stacked LSTM over the dynamic features, with static features joined before the output layer.

    Input is batch-first: ``(batch, seq_len, input_size + num_static_features)``.
    The last ``num_static_features`` columns are treated as static and only
    their value at the first time step is used.

    Args:
        input_size: number of dynamic features fed to the LSTM.
        hidden_size: LSTM hidden state width.
        num_layers: number of stacked LSTM layers.
        output_size: width of the output layer (1 for binary classification).
        num_static_features: number of trailing static columns (may be 0).

    ``forward`` returns raw logits of shape ``(batch, output_size)``.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size, num_static_features):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_static_features = num_static_features
        self.lstm = nn.LSTM(input_size, hidden_size, num_layers, batch_first=True)
        # The head honours output_size instead of a hard-coded width of 1.
        self.fc2 = nn.Linear(hidden_size + num_static_features, output_size)

    def forward(self, x):
        # Explicit split index: a "-num_static_features" slice would select
        # nothing / everything when num_static_features is 0.
        split = x.size(2) - self.num_static_features
        lstm_input = x[:, :, :split]
        static_input = x[:, 0, split:]

        # nn.LSTM starts from zero hidden / cell states when none are given.
        out, _ = self.lstm(lstm_input)
        out = out[:, -1, :]  # Take the output of the last time step

        out = torch.cat((out, static_input), dim=1)
        out = self.fc2(out)
        return out

class LSTMModel2(nn.Module):
    """One independent LSTM per dynamic feature, concatenated with the static features.

    Input is batch-first: ``(batch, seq_len, num_features)``. The first
    ``num_features_dynamic`` columns are each fed, as a one-wide sequence, to
    their own LSTM; the last ``num_features_static`` columns are treated as
    static and only their value at the first time step is used.

    Args:
        input_size: input width of every per-feature LSTM. ``forward`` feeds
            one-wide sequences, so this has to be 1.
        hidden_size: hidden state width of every LSTM.
        num_layers: number of stacked layers of every LSTM.
        output_size: width of the output layer.
        num_features_dynamic: number of leading dynamic columns.
        num_features_static: number of trailing static columns (may be 0).

    ``forward`` returns raw logits of shape ``(batch, output_size)``.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size, num_features_dynamic, num_features_static):

        super().__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers
        self.num_features_dynamic = num_features_dynamic
        self.num_features_static = num_features_static

        # One LSTM per dynamic feature
        self.lstms = nn.ModuleList([nn.LSTM(input_size, hidden_size, num_layers, batch_first=True) for _ in range(num_features_dynamic)])

        # Output layer over the concatenated LSTM states and static features
        self.fc = nn.Linear(hidden_size*num_features_dynamic + num_features_static, output_size)

    def forward(self, x):

        lstm_outputs = []

        for i in range(self.num_features_dynamic):
            feature_input = x[:, :, i].unsqueeze(2)  # Shape: (batch_size, seq_length, 1)
            # nn.LSTM starts from zero hidden / cell states when none are given.
            lstm_out, _ = self.lstms[i](feature_input)
            lstm_outputs.append(lstm_out[:, -1, :])  # Last time step: (batch_size, hidden_size)

        # Explicit start index: "-num_features_static:" would select every
        # column when num_features_static is 0.
        static_start = x.size(2) - self.num_features_static
        static_input = x[:, 0, static_start:]

        # Shape: (batch_size, hidden_size * num_features_dynamic + num_features_static)
        out = torch.cat(lstm_outputs + [static_input], dim=1)

        out = self.fc(out)

        return out

In [49]:
def _split_with_min_positives(features, labels, holdout_fraction, min_positives):
    """Randomly split until the held-out part contains at least ``min_positives`` positive labels."""
    if np.sum(labels) < min_positives:
        raise ValueError(
            f"Cannot hold out {min_positives} positive sample(s): only {np.sum(labels)} available"
        )
    while True:
        x_keep, x_hold, y_keep, y_hold = train_test_split(features, labels, test_size=holdout_fraction)
        if np.sum(y_hold) >= min_positives:
            return x_keep, x_hold, y_keep, y_hold


def _build_model(model_name, num_features, num_features_static, num_timesteps, batch_size, verbose):
    """Instantiate the network named ``model_name``; print its summary for the batch-first models when verbose."""
    if model_name == "CNN_1D_0":
        return CNN_1D_0(num_features=num_features, num_static_features=num_features_static)
    if model_name == "CNN_1D_1":
        return CNN_1D_1(num_features=num_features)

    if model_name == "CNN2":
        model = CNN2()
    elif model_name == "LSTM":
        # input_size, hidden_size, num_layers, output_size, num_static_features
        model = LSTMModel(num_features - num_features_static, 32, 2, 1, num_features_static)
    elif model_name == "LSTM2":
        # input_size, hidden_size, num_layers, output_size, num_features_dynamic, num_features_static
        model = LSTMModel2(1, 16, 1, 1, num_features - num_features_static, num_features_static)
    else:
        raise ValueError(f"Unknown model_name: {model_name!r}")

    if verbose:
        print(torchinfo.summary(model, input_size=(batch_size, num_timesteps, num_features)))
    return model


def _run_inference(model, loader, channels_first, criterion=None):
    """Run ``model`` over ``loader`` in eval mode; return (labels, probabilities, summed loss)."""
    model.eval()
    labels_all = []
    probabilities = []
    total_loss = 0.0
    with torch.no_grad():
        for inputs, labels in loader:
            if channels_first:
                inputs = inputs.permute(0, 2, 1)
            # squeeze(-1) keeps the batch axis even for a batch of one sample.
            logits = model(inputs).squeeze(-1)
            if criterion is not None:
                total_loss += criterion(logits, labels).item()
            probabilities.extend(torch.sigmoid(logits).numpy())
            labels_all.extend(labels.numpy())
    return labels_all, probabilities, total_loss


def train_model(data, target, num_epochs, model_name, test_size, val_size, verbose, save_path, save_model, plot_train_curves):
    """Train one model on a fresh random split and evaluate it on the held-out test set.

    The weights of the epoch with the best validation AUROC are restored before
    testing. A run is only reported when, at that epoch, the training AUROC is
    strictly greater than the validation AUROC; otherwise seven NaNs are
    returned.

    Args:
        data: array of shape (num_samples, num_timesteps, num_features); the
            last 3 features are treated as static.
        target: 0/1 labels aligned with ``data``.
        num_epochs: number of training epochs.
        model_name: one of "CNN_1D_0", "CNN_1D_1", "CNN2", "LSTM", "LSTM2".
        test_size: fraction of all samples held out for testing (0 for none).
        val_size: fraction of the remaining samples held out for validation
            (0 for none).
        verbose: print per-epoch metrics (and the model summary where available).
        save_path: file receiving the best checkpoint when ``save_model`` is true.
        save_model: write the best checkpoint to ``save_path`` each time the
            validation AUROC improves.
        plot_train_curves: plot train / validation AUROC per epoch (only when a
            validation set exists).

    Returns:
        ``(auroc, precision, recall, tnr, f1, accuracy, best_val_auroc)``.
        Test metrics are NaN when there is no test set; ``best_val_auroc`` is
        NaN when there is no validation set (the last-epoch model is then
        evaluated and always reported).

    Raises:
        ValueError: unknown ``model_name``, or not enough positive samples to
            populate a hold-out set.
    """
    import copy

    # ---- Split -------------------------------------------------------------
    if test_size > 0:
        x_pool, x_test, y_pool, y_test = _split_with_min_positives(data, target, test_size, 2)
        if val_size > 0:
            # Retries always re-split the same pool, so the training set does
            # not shrink with every retry.
            x_train, x_val, y_train, y_val = _split_with_min_positives(x_pool, y_pool, val_size, 1)
        else:
            x_train, y_train = x_pool, y_pool
            x_val, y_val = np.array([]), np.array([])
    elif val_size > 0:
        x_test, y_test = np.array([]), np.array([])
        x_train, x_val, y_train, y_val = _split_with_min_positives(data, target, val_size, 2)
    else:
        x_train, y_train = data, target
        x_val, y_val = np.array([]), np.array([])
        x_test, y_test = np.array([]), np.array([])

    has_val = np.size(y_val) > 0
    has_test = np.size(y_test) > 0

    # Shapes are read from the data itself rather than from notebook globals.
    num_timesteps = np.size(x_train, 1)
    num_features = np.size(x_train, 2)
    num_features_static = 3

    batch_size = 32

    # Positive-class weight of the loss: inverse of the positive rate in the training set.
    proportion_1 = np.sum(y_train)/np.size(y_train)
    pos_weight = torch.tensor(1/proportion_1, dtype=torch.float32)

    # ---- Tensors and loaders -------------------------------------------------
    x_train_tensor = torch.tensor(x_train, dtype=torch.float32)
    y_train_tensor = torch.tensor(y_train, dtype=torch.float32)
    x_val_tensor = torch.tensor(x_val, dtype=torch.float32)
    y_val_tensor = torch.tensor(y_val, dtype=torch.float32)
    x_test_tensor = torch.tensor(x_test, dtype=torch.float32)
    y_test_tensor = torch.tensor(y_test, dtype=torch.float32)

    train_loader = DataLoader(TensorDataset(x_train_tensor, y_train_tensor), batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(TensorDataset(x_val_tensor, y_val_tensor), batch_size=batch_size, shuffle=False)
    test_loader = DataLoader(TensorDataset(x_test_tensor, y_test_tensor), batch_size=batch_size, shuffle=False)

    # ---- Model ---------------------------------------------------------------
    model = _build_model(model_name, num_features, num_features_static, num_timesteps, batch_size, verbose)

    # Only the Conv1d models expect (batch, features, time); CNN2 and the LSTMs
    # consume the (batch, time, features) layout of ``data`` directly.
    channels_first = model_name in ("CNN_1D_0", "CNN_1D_1")

    # The models output logits; the sigmoid is part of the loss.
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight)
    optimizer = optim.Adam(model.parameters())

    best_val_auroc = 0
    train_auroc_at_best_val_auroc = -np.inf
    best_state = None
    val_auroc_list = []
    train_auroc_list = []

    # ---- Training ------------------------------------------------------------
    for epoch in range(num_epochs):

        model.train()
        running_loss = 0.0
        predictions = []
        true_labels = []

        for inputs, labels in train_loader:
            optimizer.zero_grad()

            if channels_first:
                inputs = inputs.permute(0, 2, 1)

            # squeeze(-1) keeps the batch axis even for a batch of one sample.
            outputs = model(inputs).squeeze(-1)
            loss = criterion(outputs, labels)

            predictions.extend(torch.sigmoid(outputs).detach().numpy())
            true_labels.extend(labels.numpy())

            loss.backward()
            optimizer.step()
            running_loss += loss.item()

        train_auroc = roc_auc_score(true_labels, predictions)
        train_auroc_list.append(train_auroc)

        if verbose:
            print(f"Epoch {epoch+1}/{num_epochs}, Train loss: {running_loss:.4f}, Train AUROC: {train_auroc:.4f}")

        if has_val:
            val_labels, val_predictions, val_loss = _run_inference(model, val_loader, channels_first, criterion)

            val_auroc = roc_auc_score(val_labels, val_predictions)
            val_auroc_list.append(val_auroc)
            if verbose:
                print(f"Validation Loss: {val_loss:.4f}, Validation AUROC: {val_auroc:.4f}")

            if val_auroc > best_val_auroc:
                best_val_auroc = val_auroc
                train_auroc_at_best_val_auroc = train_auroc
                # Snapshot in memory so the best weights can be restored even
                # when nothing is written to disk.
                best_state = copy.deepcopy(model.state_dict())
                if save_model:
                    torch.save(best_state, save_path)

    if best_state is not None:
        model.load_state_dict(best_state)

    if has_val and plot_train_curves:
        plt.figure(figsize=(10, 6))

        plt.plot(range(num_epochs), train_auroc_list, label='Train AUROC', color='blue')
        plt.plot(range(num_epochs), val_auroc_list, label='Validation AUROC', color='red')

        plt.xlabel('epochs')
        plt.ylabel('auroc')
        plt.title('Train and Val AUROC = f(epoch)')

        plt.legend()

        plt.show()

    # ---- Test ----------------------------------------------------------------
    threshold = 0.5

    if has_test:
        true_labels, predictions, _ = _run_inference(model, test_loader, channels_first)
        # 1-D arrays: the metrics below expect one label per sample.
        predictions_binary = (np.asarray(predictions) > threshold).astype(int)

        auroc = roc_auc_score(true_labels, predictions)
        precision = precision_score(true_labels, predictions_binary, zero_division=0)
        recall = recall_score(true_labels, predictions_binary, zero_division=0)
        tnr = tnr_score(true_labels, predictions_binary)
        f1 = f1_score(true_labels, predictions_binary, zero_division=0)
        accuracy = accuracy_score(true_labels, predictions_binary)
    else:
        auroc = precision = recall = tnr = f1 = accuracy = np.nan

    if not has_val:
        # No validation-based selection is possible: report the final model.
        return auroc, precision, recall, tnr, f1, accuracy, np.nan

    if train_auroc_at_best_val_auroc > best_val_auroc:
        return auroc, precision, recall, tnr, f1, accuracy, best_val_auroc
    else:
        return np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan

In [51]:
# Train num_train models on independent random splits and report the test
# metrics of the run that reached the highest validation AUROC.
num_train = 100

# NOTE: every train_model call writes its own best checkpoint to this path, so
# after the loop the file holds the weights of the LAST run, not necessarily
# those of the run whose metrics are reported below.
save_path = "LSTMs/lstm0.pth"

# One-element lists holding the metrics of the best run so far (empty until a
# run is accepted).
aurocs = []
precisions = []
recalls = []
tnrs = []
accuracies = []
f1s = []

best_val_auroc_all_models = 0
for i in tqdm(range(num_train), total=num_train):

    auroc, precision, recall, tnr, f1, accuracy, best_val_auroc  =  train_model(data=data,
                                                                                target=target,
                                                                                num_epochs=25, 
                                                                                model_name="CNN_1D_1", 
                                                                                test_size=0.10, 
                                                                                val_size=0.10, 
                                                                                verbose=True, 
                                                                                save_path=save_path, 
                                                                                save_model=True, 
                                                                                plot_train_curves = True)

    # train_model returns NaNs for rejected runs; keep only the best accepted one.
    if not(np.isnan(auroc)) and best_val_auroc > best_val_auroc_all_models:
        aurocs = [auroc]
        precisions = [precision]
        recalls = [recall]
        tnrs = [tnr]
        accuracies = [accuracy]
        f1s = [f1]
        best_val_auroc_all_models = best_val_auroc
        print(f"New best val_auroc: {best_val_auroc_all_models}")

    # Nothing to report until a first run has been accepted.
    if aurocs:
        print(f"Test AUROC with best model: {np.mean(aurocs):.4f}")

print(f"AUROC: {np.mean(aurocs):.4f}")
print(f"Precision: {np.mean(precisions):.4f}")
print(f"Recall: {np.mean(recalls):.4f}")
print(f"Specificity: {np.mean(tnrs):.4f}")
print(f"Accuracy: {np.mean(accuracies):.4f}")
print(f"F1 Score: {np.mean(f1s):.4f}")
print(f"num_algos: {np.size(aurocs)}")

  0%|          | 0/100 [00:00<?, ?it/s]

Epoch 1/25, Train loss: 53.8362, Train AUROC: 0.5384
Validation Loss: 6.4634, Validation AUROC: 0.6698
Epoch 2/25, Train loss: 51.2595, Train AUROC: 0.6753
Validation Loss: 5.9342, Validation AUROC: 0.7140
Epoch 3/25, Train loss: 49.1360, Train AUROC: 0.7054
Validation Loss: 5.9241, Validation AUROC: 0.7359
Epoch 4/25, Train loss: 47.3517, Train AUROC: 0.7348
Validation Loss: 5.5210, Validation AUROC: 0.7271
Epoch 5/25, Train loss: 46.1555, Train AUROC: 0.7561
Validation Loss: 5.4788, Validation AUROC: 0.7428
Epoch 6/25, Train loss: 44.9303, Train AUROC: 0.7721
Validation Loss: 5.3967, Validation AUROC: 0.7404
Epoch 7/25, Train loss: 44.3378, Train AUROC: 0.7795
Validation Loss: 5.4878, Validation AUROC: 0.7376
Epoch 8/25, Train loss: 43.7598, Train AUROC: 0.7854
Validation Loss: 5.5042, Validation AUROC: 0.7305
Epoch 9/25, Train loss: 41.9905, Train AUROC: 0.8100
Validation Loss: 5.4802, Validation AUROC: 0.7349
Epoch 10/25, Train loss: 41.0706, Train AUROC: 0.8180
Validation Loss: 5.

  0%|          | 0/100 [00:07<?, ?it/s]


KeyboardInterrupt: 

In [54]:
# Build the ECMO evaluation set: same 5-day window; is_test_ECMO=True restricts
# the selection to encounters whose installation year is before 2020.
window_time_days = 5
target_ECMO_test, id_list_ECMO_test = prepareDeathList("dataECMO", window_time_days, is_test_ECMO=True)
# Feature matrices of the selected ECMO encounters, in id_list_ECMO_test order.
data_ECMO_test = prepareData("dataECMO", id_list_ECMO_test, window_time_days)

100%|██████████| 189/189 [00:05<00:00, 31.53it/s]
100%|██████████| 69/69 [00:05<00:00, 13.08it/s]
