## Import

In [None]:
import os
import csv
import sys
import pickle
import shutil
import itertools

import numpy as np
import pandas as pd
from scipy import interp

import seaborn as sns
from seaborn import heatmap
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap

from sklearn.calibration import calibration_curve
from sklearn.calibration import CalibratedClassifierCV

from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn.model_selection import StratifiedKFold
from sklearn.metrics import precision_recall_curve
from sklearn.metrics import classification_report
from sklearn.metrics import confusion_matrix, accuracy_score, auc 
from sklearn.metrics import roc_curve, matthews_corrcoef, roc_auc_score
from sklearn.metrics import mean_squared_error, average_precision_score

from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F

from torch.utils.data.dataset import random_split
from torch.utils.data.sampler import WeightedRandomSampler
from torch.utils.data import Dataset, TensorDataset, DataLoader
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence

from netcal.scaling import TemperatureScaling
from sklearn.linear_model import LogisticRegression as LR

from shap import summary_plot

from captum.attr import LayerConductance, LayerActivation, LayerIntegratedGradients
from captum.attr import IntegratedGradients, DeepLift, GradientShap, NoiseTunnel, FeatureAblation

pd.set_option('display.max_columns', 500)

torch.manual_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(42)

## Reading Data

In [None]:
all_data_real = pd.read_csv('data/all_data.csv')

all_data_real.head()

In [None]:
fake_data = pd.read_csv('data/generation_0/fake_data.csv')

fake_data.head()

In [None]:
fake_data['ethnicity'] = 1

In [None]:
print(all_data_real.shape)

print(fake_data.shape)

In [None]:
fake_data.columns = all_data_real.columns
fake_data.head()

In [None]:
# drop SpO2 numeric
fake_data = fake_data.drop('SpO2', axis=1)
fake_data = fake_data.rename(columns={'SpO2.1': 'SpO2'})


all_data_real = all_data_real.drop('SpO2', axis=1)
all_data_real = all_data_real.rename(columns={'SpO2.1': 'SpO2'})

In [None]:
print(all_data_real.shape)

print(fake_data.shape)

In [None]:
# Describe both frames once; the 'mean' and 'std' rows are reused below.
fake_data_summary, all_data_real_summary = fake_data.describe(), all_data_real.describe()

# One row per variable, with the mean/std of the fake and the real data side by side.
comparison = pd.DataFrame({
    f'{source}_{stat}': summary.loc[stat]
    for stat in ('mean', 'std')
    for source, summary in (('fake', fake_data_summary), ('real', all_data_real_summary))
})

# Absolute gap between fake and real for each statistic.
comparison['mean_diff'] = (comparison['fake_mean'] - comparison['real_mean']).abs()
comparison['std_diff'] = (comparison['fake_std'] - comparison['real_std']).abs()

print(comparison)

# Keep the variables whose mean gap or std gap exceeds the threshold.
threshold = 0.1  # Adjust this value as needed
significant_diff = comparison[comparison[['mean_diff', 'std_diff']].gt(threshold).any(axis=1)]
print("\nVariables with significant differences:")
print(significant_diff)

In [None]:
# Merge the datasets
fake_data['data_type'] = 'fake'
all_data_real['data_type'] = 'real'

merged_data = pd.concat([all_data_real, fake_data], ignore_index=True)

print(merged_data['data_type'].value_counts())
print(merged_data.shape)

In [None]:
# Count the ethnicities
ethnicity_counts = merged_data['ethnicity'].value_counts()

print("Ethnicity counts:")
print(ethnicity_counts)


In [None]:
# Every consecutive run of 15 rows (by index) is assigned to one patient id.
merged_data['patient_id'] = merged_data.index // 15

# 'bloc' is the 1-based position of the row inside its patient group.
merged_data['bloc'] = merged_data.groupby('patient_id').cumcount().add(1)

# Move 'patient_id' and 'bloc' to the front, keeping the other columns in order.
cols = ['patient_id', 'bloc']
cols.extend(col for col in merged_data.columns if col not in ('patient_id', 'bloc'))
merged_data = merged_data[cols]

merged_data.head(20)

In [None]:
# Fix 'age' as a static variable for fake data
mean_age = merged_data.groupby('patient_id')['age'].transform('mean')

merged_data['age'] = mean_age



In [None]:
# Change categorical variables to numeric with mean of the corresponding interval

variable_dict = torch.load("data/A001_BTS_nonFloat")

def get_interval_mean(value, quantiles):
    """Map an encoded category to the midpoint of its quantile interval.

    Args:
        value: Encoded category. The first interval index ``i`` with
            ``value <= i`` is selected.
        quantiles: Interval edges. When falsy (empty or ``None``) the value
            is returned unchanged.

    Returns:
        The mean of the two edges of the selected interval; the last interval
        is used when no index satisfies ``value <= i``.
    """
    if not quantiles:
        # Binary variables or those without quantiles are passed through.
        return value

    n_intervals = len(quantiles) - 1
    interval = next((i for i in range(n_intervals) if value <= i), None)

    if interval is None:
        # Nothing matched: fall back to the last interval.
        lower, upper = quantiles[-2], quantiles[-1]
    else:
        lower, upper = quantiles[interval], quantiles[interval + 1]
    return (lower + upper) / 2

# Replace the encoded value of every 'Temp_C' / 'cat' variable present in
# merged_data by the midpoint of its quantile interval; all other types
# (including 'bin' and 'GCS') are left untouched.
for name, type_, quantiles in zip(variable_dict['Name'], variable_dict['Type'], variable_dict['Quantiles']):
    if name not in merged_data.columns or type_ not in ('Temp_C', 'cat'):
        continue
    merged_data[name] = merged_data[name].apply(lambda x, q=quantiles: get_interval_mean(x, q))

## Labeling

In [None]:
# Row-level label: 1.0 where Arterial_lactate is above 4, 0.0 everywhere else
# (rows with a missing lactate value compare as False and therefore get 0.0).
merged_data['label'] = (merged_data['Arterial_lactate'] > 4).astype(float)

In [None]:
# count 'lactate' > 4 for each timestep
filtered_df = merged_data[merged_data['label'] == 1]

bloc_counts = filtered_df.groupby('bloc').size()

bloc_counts_df = bloc_counts.reset_index(name='count')

print(bloc_counts_df)

In [None]:
# Rows of the last timestep (bloc 15) of every patient.
bloc_15_df = merged_data[merged_data['bloc'] == 15]

# Patients whose lactate at that last timestep is above 4.
patients_with_high_lactate = bloc_15_df[bloc_15_df['Arterial_lactate'] > 4]['patient_id']

# Patient-level label broadcast to every row of the patient: 1 for the patients
# selected above, 0 otherwise. `isin` does the membership test in one
# vectorised pass instead of scanning the id array once per row.
merged_data['label'] = merged_data['patient_id'].isin(patients_with_high_lactate).astype(int)

print(merged_data['label'].describe())

In [None]:
# Count positive labels for each ethnicity
filtered_df = merged_data[merged_data['label'] == 1]

ethnicity_counts = filtered_df.groupby('ethnicity')['patient_id'].nunique().reset_index(name='count')

print(ethnicity_counts)

In [None]:
# Drop the last timesteps
merged_data = merged_data[merged_data['bloc'] != 15]

# Imputation

In [None]:
# Check for any NaN in any column

has_nan = merged_data.isna().any().any()  

if has_nan:
    print("There are NaN values in the DataFrame.")
else:
    print("No NaN values found!")

## Normalization

In [None]:
# Columns that appear both in the model input and in the scaled set.
_scaled_head = [
    'age', 'HR', 'SysBP', 'MeanBP', 'DiaBP', 'RR', 'Potassium', 'Sodium', 'Chloride', 'Calcium',
    'Ionised_Ca', 'CO2_mEqL', 'Albumin', 'Hb', 'Arterial_pH', 'Arterial_BE', 'HCO3', 'FiO2_1',
    'Glucose', 'BUN', 'Creatinine', 'Magnesium', 'SGOT', 'SGPT', 'Total_bili', 'WBC_count',
    'Platelets_count', 'paO2', 'paCO2', 'Arterial_lactate', 'input_total', 'input_4hourly',
    'max_dose_vaso', 'output_total', 'output_4hourly',
]

# Columns fed to the model but left out of the scaled set.
_unscaled = ['gender', 're_admission', 'mechvent']

_scaled_tail = ['Temp_C', 'GCS', 'SpO2', 'PTT', 'PT', 'INR']

# Model input columns, in the order they are stacked into the feature array.
var_list = _scaled_head + _unscaled + _scaled_tail

# Columns passed through the scaler: var_list without the unscaled ones.
norm = _scaled_head + _scaled_tail

In [None]:
print(len(var_list))
print(len(norm))


In [None]:
# Normalize numerical variables

def Normalizing(train_df, valid_df, test_df, columns=None):
    """Standardise the selected columns using statistics of the training split.

    A ``StandardScaler`` is fitted on ``train_df`` only and then applied to
    all three splits, so no statistics leak from validation/test data.

    The inputs are no longer modified: each split is copied before the scaled
    values are written. The callers pass filtered slices of a larger frame,
    and assigning into such slices is pandas' chained-assignment pitfall.

    Args:
        train_df: Training split; the scaler is fitted on it.
        valid_df: Validation split.
        test_df: Test split.
        columns: Columns to scale. Defaults to the module-level ``norm`` list.

    Returns:
        Tuple ``(train_df, valid_df, test_df)`` of scaled copies.
    """
    cols = norm if columns is None else list(columns)

    # Fit on the training values only.
    scaler = StandardScaler().fit(train_df[cols].values)

    scaled_frames = []
    for frame in (train_df, valid_df, test_df):
        frame = frame.copy()  # never write into the caller's (possibly sliced) frame
        frame[cols] = scaler.transform(frame[cols].values)
        scaled_frames.append(frame)

    return tuple(scaled_frames)

## Data Preparation

In [None]:
def Data_Prepare(df):
    """Split a long-format frame into per-patient feature arrays and targets.

    Rows are grouped by ``patient_id``. For every group the ``var_list``
    columns become one 2-D array, while the label and the ethnicity are read
    from the first row of the group.

    Returns:
        ``(X, y, ethnicity)``: three lists with one entry per patient group,
        in the iteration order of the groupby.
    """
    frames = [frame for _, frame in df.groupby(['patient_id'])]

    X = [frame[var_list].values for frame in frames]
    y = [frame.iloc[0]['label'] for frame in frames]
    ethnicity = [frame.iloc[0]['ethnicity'] for frame in frames]

    return X, y, ethnicity

## Device

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

#kwargs = {'num_workers': 1, 'pin_memory': True} if device=='cuda' else {}
kwargs = {}

## Dataset - Dataloader

In [None]:
class LactateData(Dataset):
    """Dataset yielding ``(features, target, length)`` triples.

    ``X`` and ``y`` are cast to float and stored as ``torch.FloatTensor``;
    ``l`` is stored as given and indexed alongside them.
    """

    def __init__(self, X, y, l):
        features = X.astype('float')
        targets = y.astype('float')

        self.X = torch.FloatTensor(features)
        self.y = torch.FloatTensor(targets)
        self.l = l

    def __len__(self):
        """Number of samples (first dimension of the feature tensor)."""
        return len(self.X)

    def __getitem__(self, index):
        """Return the sample stored at ``index``."""
        return self.X[index], self.y[index], self.l[index]

# Model

In [None]:
class RNN(nn.Module):
    """Unidirectional LSTM classifier over padded sequences.

    The hidden state of the top LSTM layer goes through dropout (p=0.5) and a
    linear layer that outputs ``num_classes`` logits.
    """

    def __init__(self, input_size, hidden_size, num_layers, num_classes):
        super().__init__()

        self.hidden_size = hidden_size
        self.num_layers = num_layers

        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            bidirectional=False,
        )
        self.drop = nn.Dropout(p=0.5)
        self.fc = nn.Linear(hidden_size, num_classes)

    def forward(self, x, seq_lengths):
        """Return logits for a batch-first padded batch ``x``.

        ``seq_lengths`` holds the true length of every sequence; it is moved
        to the CPU because the packing helper is given CPU lengths here.
        """
        packed = pack_padded_sequence(
            x, seq_lengths.to('cpu'), batch_first=True, enforce_sorted=False
        )

        _, (hidden, _) = self.lstm(packed)

        # hidden[-1] is the final hidden state of the last LSTM layer.
        top_layer_state = hidden[-1]
        return self.fc(self.drop(top_layer_state))

## Hyperparameters

In [None]:
input_size  = 44

hidden_size = 256

number_layers  = 1

number_classes = 2

num_epochs = 30

batch_size = 64

learning_rate = 0.0008

In [None]:
def build_model(input_size, hidden_size, number_layers, number_classes):
    """Create an ``RNN`` with the given sizes and move it to ``device``."""
    return RNN(input_size, hidden_size, number_layers, number_classes).to(device)

## Weight Initialization

In [None]:
def weight_init(m):
    """Re-initialise the parameters of ``m`` in place (for ``model.apply``).

    * ``nn.Linear``: Xavier-normal weight, standard-normal bias.
    * ``nn.LSTM`` / ``nn.GRU``: Xavier-uniform for parameters with two or
      more dimensions, standard-normal for the rest.

    Any other module type is left untouched.
    """
    if isinstance(m, nn.Linear):
        nn.init.xavier_normal_(m.weight.data)
        nn.init.normal_(m.bias.data)
        return

    if isinstance(m, (nn.LSTM, nn.GRU)):
        for param in m.parameters():
            is_matrix = len(param.shape) >= 2
            initialise = nn.init.xavier_uniform_ if is_matrix else nn.init.normal_
            initialise(param.data)

## Confusion Plot

In [None]:
def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Oranges):
    """
    Print a confusion matrix and draw it as an annotated image.

    With `normalize=True` every row is divided by its row sum before it is
    printed and plotted.
    Source: http://scikit-learn.org/stable/auto_examples/model_selection/plot_confusion_matrix.html
    """
    if not normalize:
        print('Confusion matrix, without normalization')
    else:
        row_totals = cm.sum(axis=1)[:, np.newaxis]
        cm = cm.astype('float') / row_totals
        print("Normalized confusion matrix")

    print(cm)

    # Draw the matrix itself
    plt.figure(figsize=(5, 5))
    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title, size=12)
    plt.colorbar(aspect=4)

    tick_positions = np.arange(len(classes))
    plt.xticks(tick_positions, classes, rotation=45, size=14)
    plt.yticks(tick_positions, classes, size=14)

    cell_format = 'd' if not normalize else '.2f'
    half_of_max = cm.max() / 2.

    # Write every cell value on top of the image; cells above half of the
    # maximum get white text, the others black.
    for row in range(cm.shape[0]):
        for col in range(cm.shape[1]):
            cell = cm[row, col]
            text_colour = "white" if cell > half_of_max else "black"
            plt.text(col, row, format(cell, cell_format), fontsize=12,
                     horizontalalignment="center",
                     color=text_colour)

    plt.grid(None)
    plt.tight_layout()
    plt.ylabel('True label', size=12)
    plt.xlabel('Predicted label', size=12)
    plt.show()

## Calibration Plot

In [None]:
def plot_calibration_curve(name, fig_index, true_labels, probs):
    """Draw a reliability curve and a histogram of the predicted probabilities.

    The upper axes show the fraction of positives against the mean predicted
    value (10 bins) next to the diagonal of a perfectly calibrated model; the
    lower axes show a 10-bin histogram of ``probs`` on [0, 1].
    """
    plt.figure(fig_index, figsize=(8, 6))
    curve_ax = plt.subplot2grid((3, 1), (0, 0), rowspan=2)
    hist_ax = plt.subplot2grid((3, 1), (2, 0))

    # Reference diagonal
    curve_ax.plot([0, 1], [0, 1], "k:", label="Perfectly calibrated")

    frac_of_pos, mean_pred_value = calibration_curve(true_labels, probs, n_bins=10)

    curve_ax.plot(mean_pred_value, frac_of_pos, "s-", label=f'{name}')
    curve_ax.set_ylabel("Fraction of positives")
    curve_ax.set_ylim([-0.05, 1.05])
    curve_ax.legend(loc="lower right")
    curve_ax.set_title(f'Calibration plot ({name})')

    hist_ax.hist(probs, range=(0, 1), bins=10, label=name, histtype="step", lw=2)
    hist_ax.set_xlabel("Mean predicted value")
    hist_ax.set_ylabel("Count")

## Model With Temperature

In [None]:
class ModelWithTemperature(nn.Module):
    """
    A thin decorator, which wraps a model with temperature scaling.

    model (nn.Module):
        A classification neural network, called as ``model(inputs, lengths)``.
        NB: Output of the neural network should be the classification logits,
            NOT the softmax (or log softmax)!

    Only the temperature is optimised by the tuning methods; the wrapped
    model's logits are collected once with gradients disabled.
    """

    def __init__(self, model):
        super().__init__()

        self.model = model

        # Single learnable scalar, initialised at 1.1.
        self.temperature = nn.Parameter(torch.ones(1) * 1.1)

    def forward(self, inputs, length):
        """Return the wrapped model's logits divided by the temperature."""
        return self.temperature_scale(self.model(inputs, length))

    def temperature_scale(self, logits):
        """
        Perform temperature scaling on logits
        """
        # Expand temperature to match the size of logits
        temperature = self.temperature.unsqueeze(1).expand(logits.size(0), logits.size(1))

        return logits / temperature

    @staticmethod
    def _tuning_device():
        """Return CUDA when available, CPU otherwise (was a hard-coded ``.cuda()``)."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def _collect_logits(self, valid_loader, device):
        """Run the wrapped model over ``valid_loader`` and gather logits and labels.

        The wrapped model is switched to eval mode for the pass (so dropout
        cannot perturb the logits the temperature is fitted on) and its
        previous mode is restored afterwards. Gradients are disabled.

        Returns:
            ``(logits, labels)`` concatenated over all batches, on ``device``.
        """
        logits_list = []
        labels_list = []

        was_training = self.model.training
        self.model.eval()
        try:
            with torch.no_grad():
                for inputs, labels, length in valid_loader:
                    inputs = inputs.to(device)
                    length = length.to(device)
                    labels = labels.to(device).long()

                    # Same ordering as the training loop: longest sequence first,
                    # with inputs and labels permuted consistently.
                    seq_lengths, perm_idx = length.sort(0, descending=True)

                    logits_list.append(self.model(inputs[perm_idx], seq_lengths))
                    labels_list.append(labels[perm_idx])
        finally:
            self.model.train(was_training)

        return torch.cat(logits_list), torch.cat(labels_list)

    def _metrics(self, logits, labels, nll_criterion):
        """Return ``(nll, ece)`` of ``logits`` as floats (used for verbose output only)."""
        ece_criterion = _ECELoss().to(logits.device)
        with torch.no_grad():
            nll = nll_criterion(logits, labels).item()
            ece = ece_criterion(logits, labels).item()
        return nll, ece

    # This function probably should live outside of this class, but whatever
    def set_temperature(self, valid_loader, verbose=False):
        """
        Tune the temperature of the model (using the validation set).
        We're going to set it to optimize NLL.

        valid_loader (DataLoader): validation set loader yielding
            ``(inputs, labels, lengths)`` batches
        verbose (bool): print NLL/ECE before and after, and the temperature

        Returns ``self``.
        """
        device = self._tuning_device()
        self.to(device)

        nll_criterion = nn.CrossEntropyLoss().to(device)

        # First: collect all the logits and labels for the validation set
        logit, label = self._collect_logits(valid_loader, device)

        if verbose:
            print('Before temperature - NLL: %.3f, ECE: %.3f'
                  % self._metrics(logit, label, nll_criterion))

        # Next: optimize the temperature w.r.t. NLL
        # NOTE(review): this performs a single Adam step, so the temperature
        # moves by about one learning rate here; most of the tuning happens in
        # train_temprature. Confirm that this is intended.
        optimizer = optim.Adam([self.temperature], lr=0.001)

        def closure():
            # Clear stale gradients so repeated calls do not accumulate them.
            optimizer.zero_grad()
            loss = nll_criterion(self.temperature_scale(logit), label)
            loss.backward()
            return loss

        optimizer.step(closure)

        if verbose:
            with torch.no_grad():
                scaled = self.temperature_scale(logit)
            print('After temperature - NLL: %.3f, ECE: %.3f'
                  % self._metrics(scaled, label, nll_criterion))
            print('Optimal temperature: %.3f' % self.temperature.item())

        return self

    def train_temprature(self, valid_loader, verbose=False):
        """
        Refine the temperature with 25 Adam steps on the validation NLL.

        The wrapped model's logits are computed once, in eval mode and without
        gradients, and reused for every step: only the temperature changes
        between steps. The training/eval mode of the modules is left as it was
        on entry.

        valid_loader (DataLoader): validation set loader yielding
            ``(inputs, labels, lengths)`` batches
        verbose (bool): print NLL/ECE after every step

        Returns ``self``.
        """
        device = self._tuning_device()
        self.to(device)

        nll_criterion = nn.CrossEntropyLoss().to(device)

        optimizer = optim.Adam([self.temperature], lr=0.001)

        logit, label = self._collect_logits(valid_loader, device)

        n_epochs = 25

        for epoch in range(n_epochs):
            loss = nll_criterion(self.temperature_scale(logit), label)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            if verbose:
                with torch.no_grad():
                    scaled = self.temperature_scale(logit)
                print("Epoch: {}/{} ".format(epoch + 1, n_epochs),
                      'Temperature - NLL: %.3f, ECE: %.3f'
                      % self._metrics(scaled, label, nll_criterion))

        return self

In [None]:
class _ECELoss(nn.Module):
    """
    Calculates the Expected Calibration Error of a model.
    (This isn't necessary for temperature scaling, just a cool metric).

    The input to this loss is the logits of a model, NOT the softmax scores.

    This divides the confidence outputs into equally-sized interval bins.
    In each bin, we compute the confidence gap:

    bin_gap = | avg_confidence_in_bin - accuracy_in_bin |

    We then return a weighted average of the gaps, based on the number
    of samples in each bin

    See: Naeini, Mahdi Pakdaman, Gregory F. Cooper, and Milos Hauskrecht.
    "Obtaining Well Calibrated Probabilities Using Bayesian Binning." AAAI.
    2015.
    """
    
    def __init__(self, n_bins=10):
        """
        n_bins (int): number of confidence interval bins
        """
        
        super(_ECELoss, self).__init__()
        
        bin_boundaries  = torch.linspace(0, 1, n_bins + 1)
        
        self.bin_lowers = bin_boundaries[:-1]
        
        self.bin_uppers = bin_boundaries[1:]

        
    def forward(self, logits, labels):
        
        softmaxes = F.softmax(logits, dim=1)
        
        confidences, predictions = torch.max(softmaxes, 1)
        
        accuracies = predictions.eq(labels)

        ece = torch.zeros(1, device=logits.device)
        
        for bin_lower, bin_upper in zip(self.bin_lowers, self.bin_uppers):
            
            # Calculated |confidence - accuracy| in each bin
            
            in_bin = confidences.gt(bin_lower.item()) * confidences.le(bin_upper.item())
            
            prop_in_bin = in_bin.float().mean()
            
            if prop_in_bin.item() > 0:
                
                accuracy_in_bin = accuracies[in_bin].float().mean()
                
                avg_confidence_in_bin = confidences[in_bin].mean()
                
                ece += torch.abs(avg_confidence_in_bin - accuracy_in_bin) * prop_in_bin

        return ece

# Early Stopping

In [None]:
class EarlyStopping:
    """Stop training when the validation loss has stopped improving.

    Call the instance once per epoch with the validation loss and the model.
    Whenever the loss improves on the best value seen so far (by at least
    ``min_delta``), the model's ``state_dict`` is saved to ``path`` and the
    patience counter is reset; otherwise the counter is incremented and
    ``early_stop`` becomes ``True`` once it reaches ``patience``.

    Args:
        patience: Number of consecutive non-improving calls tolerated.
        min_delta: Minimum decrease of the loss that counts as an improvement.
        verbose: Print a message on every counter increment and checkpoint.
        path: File the best ``state_dict`` is written to.
    """

    def __init__(self, patience=10, min_delta=0, verbose=False, path='checkpoint.pt'):
        self.patience = patience
        self.min_delta = min_delta
        self.verbose = verbose
        self.path = path
        self.counter = 0
        self.best_score = None
        self.early_stop = False
        # np.inf instead of the np.Inf alias, which NumPy 2.0 removed.
        self.val_loss_min = np.inf

    def __call__(self, val_loss, model):
        """Record one validation loss and update the stopping state."""
        # Higher score is better, so the loss is negated.
        score = -val_loss

        if self.best_score is None:
            # First call: always becomes the reference point.
            self.best_score = score
            self.save_checkpoint(val_loss, model)
        elif score < self.best_score + self.min_delta:
            self.counter += 1
            if self.verbose:
                print(f'EarlyStopping counter: {self.counter} out of {self.patience}')
            if self.counter >= self.patience:
                self.early_stop = True
        else:
            self.best_score = score
            self.save_checkpoint(val_loss, model)
            self.counter = 0

    def save_checkpoint(self, val_loss, model):
        """Save the model's ``state_dict`` to ``self.path`` and remember ``val_loss``."""
        if self.verbose:
            print(f'Validation loss decreased ({self.val_loss_min:.6f} --> {val_loss:.6f}). Saving model ...')
        torch.save(model.state_dict(), self.path)
        self.val_loss_min = val_loss

# Cross Validation

In [None]:
# Split the merged frame back into its real and fake rows
real_data = merged_data[merged_data['data_type'] == 'real']
fake_data = merged_data[merged_data['data_type'] == 'fake']

# One row per real patient (the first one), plus a column that combines
# 'label' and 'ethnicity' for stratification. The explicit copy makes the
# column assignment below write to an independent frame instead of an object
# derived from a filtered slice (pandas' chained-assignment pitfall).
indexing_real = real_data[['patient_id', 'label', 'ethnicity']].groupby('patient_id').head(1).copy()
indexing_real['stratify_col'] = indexing_real['label'].astype(str) + '_' + indexing_real['ethnicity'].astype(int).astype(str)

# One row per fake patient (the first one)
indexing_fake = fake_data[['patient_id', 'label', 'ethnicity']].groupby('patient_id').head(1).copy()

## LSTM

In [None]:
torch.manual_seed(42)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(42)

In [None]:
# Result accumulators. Every metric has three independent lists: one for
# ethnicity == 1 (`_eth1`), one for ethnicity == 0 (`_eth0`) and one overall.
Overall_ACC_ls_eth1, Overall_ACC_ls_eth0, Overall_ACC_ls = [], [], []
Overall_PRC_ls_eth1, Overall_PRC_ls_eth0, Overall_PRC_ls = [], [], []
Overall_REC_ls_eth1, Overall_REC_ls_eth0, Overall_REC_ls = [], [], []
Overall_F1M_ls_eth1, Overall_F1M_ls_eth0, Overall_F1M_ls = [], [], []
Overall_PPV_ls_eth1, Overall_PPV_ls_eth0, Overall_PPV_ls = [], [], []
Overall_NPV_ls_eth1, Overall_NPV_ls_eth0, Overall_NPV_ls = [], [], []
Overall_SEN_ls_eth1, Overall_SEN_ls_eth0, Overall_SEN_ls = [], [], []
Overall_SPE_ls_eth1, Overall_SPE_ls_eth0, Overall_SPE_ls = [], [], []
Overall_MCC_ls_eth1, Overall_MCC_ls_eth0, Overall_MCC_ls = [], [], []
Overall_AUC_ls_eth1, Overall_AUC_ls_eth0, Overall_AUC_ls = [], [], []
Overall_AP_ls_eth1, Overall_AP_ls_eth0, Overall_AP_ls = [], [], []
No_Skill_ls_eth1, No_Skill_ls_eth0, No_Skill_ls = [], [], []

# ROC accumulators and the common 100-point grid on [0, 1]
tprs_ls_eth1, tprs_ls_eth0, tprs_ls = [], [], []
aucs_ls_eth1, aucs_ls_eth0, aucs_ls = [], [], []
mean_fpr_ls_eth1, mean_fpr_ls_eth0, mean_fpr_ls = (np.linspace(0, 1, 100) for _ in range(3))

# Precision-recall accumulators and the common 100-point grid on [0, 1]
prs_ls_eth1, prs_ls_eth0, prs_ls = [], [], []
ap_ls_eth1, ap_ls_eth0, ap_ls = [], [], []
mean_recall_ls_eth1, mean_recall_ls_eth0, mean_recall_ls = (np.linspace(0, 1, 100) for _ in range(3))

# Probabilities and labels kept for the calibration plots
cal_prob_ls_eth1, cal_prob_ls_eth0, cal_prob_ls = [], [], []
cal_label_ls_eth1, cal_label_ls_eth0, cal_label_ls = [], [], []

Total_attribute = []

Total_features = []

In [None]:
import warnings
warnings.filterwarnings('ignore')

for i in range(10):

    skf = StratifiedKFold(n_splits=3, random_state=i*5, shuffle=True)

    for train_index, test_index in skf.split(indexing_real.patient_id, indexing_real.stratify_col):


        train = indexing_real.iloc[train_index]

        test  = indexing_real.iloc[test_index]
        
        
        train, valid = train_test_split(train, stratify= train.label, test_size= 0.20, random_state= 42)


        lactate_train_id = train.patient_id.unique()
        
        lactate_valid_id = valid.patient_id.unique()

        lactate_test_id  = test.patient_id.unique()


        train_df = merged_data[merged_data['patient_id'].isin(lactate_train_id)]
        
        valid_df = merged_data[merged_data['patient_id'].isin(lactate_valid_id)]

        test_df  = merged_data[merged_data['patient_id'].isin(lactate_test_id)]


        train_df_normilized, valid_df_normilized, test_df_normalized = Normalizing(train_df, valid_df, test_df)


        X_train, y_train, ethnicity_train = Data_Prepare(train_df_normilized)
        
        X_valid, y_valid, ethnicity_valid = Data_Prepare(valid_df_normilized)

        X_test , y_test, ethnicity_test  = Data_Prepare(test_df_normalized)


        X_data_train = np.array(X_train)
        
        X_data_valid = np.array(X_valid)
        
        X_data_test  = np.array(X_test)
        
        
        y_train = np.array(y_train)
        
        y_valid = np.array(y_valid)
        
        y_test  = np.array(y_test)


        ethnicity_test = np.array(ethnicity_test)


        X_len_train = np.full(len(y_train), 14)
        X_len_valid = np.full(len(y_valid), 14)
        X_len_test = np.full(len(y_test), 14)


        train_data = LactateData(X_data_train, y_train, X_len_train)
        valid_data = LactateData(X_data_valid, y_valid, X_len_valid)
        test_data  = LactateData(X_data_test , y_test , X_len_test)
        
        
        train_class_sample_count = torch.tensor([(torch.tensor(y_train) == t).sum() for t in torch.unique(torch.tensor(y_train), sorted=True)])
        train_weight  = 1 / train_class_sample_count.float()
        train_samples_weight = torch.tensor([train_weight[i] for i in torch.tensor(y_train).long()])
        train_sampler = WeightedRandomSampler(train_samples_weight, len(train_samples_weight))

        
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False, sampler=train_sampler, **kwargs)
        valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False, **kwargs)
        test_loader  = DataLoader(test_data , batch_size=batch_size, shuffle=False, **kwargs)
        
        
        model = build_model(input_size, hidden_size, number_layers, number_classes)
        
        model.apply(weight_init)
        
        
        class_weight = torch.Tensor([0.65, 0.35])
        class_weight = class_weight.to(device)
        criterion    = nn.CrossEntropyLoss(weight= class_weight)
        
        l1_crit   = nn.L1Loss(size_average=False)
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        
        ##########################################################################
        
        # Initialize early stopping
        early_stopping = EarlyStopping(patience=10, verbose=False)

        train_losses = []
        valid_losses = []  # Added to track validation losses

        for epoch in range(num_epochs):
            train_loss = 0
            model.train()

            for x, y, l in train_loader:
                x = x.to(device)
                y = y.to(device)
                l = l.to(device)
                y = y.long()

                seq_lengths, perm_idx = l.sort(0, descending=True)
                x = x[perm_idx]
                y = y[perm_idx]

                outputs = model(x, seq_lengths)
                entropy_loss = criterion(outputs, y)

                l1_loss_ = 0
                for param in model.lstm.parameters():
                    l1_loss_ += l1_crit(param, target=torch.zeros_like(param))
                factor_1 = 0.0006
                l1_loss = factor_1 * l1_loss_

                loss = entropy_loss + l1_loss 

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()   

                train_loss += loss.item()   
            
            train_losses.append(train_loss/len(train_loader))

            # Validation step
            model.eval()
            valid_loss = 0
            with torch.no_grad():
                for x, y, l in valid_loader:
                    x = x.to(device)
                    y = y.to(device)
                    l = l.to(device)
                    y = y.long()

                    seq_lengths, perm_idx = l.sort(0, descending=True)
                    x = x[perm_idx]
                    y = y[perm_idx]

                    outputs = model(x, seq_lengths)
                    loss = criterion(outputs, y)  # Using only entropy loss for validation
                    valid_loss += loss.item()

            valid_loss = valid_loss / len(valid_loader)
            valid_losses.append(valid_loss)

            #print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_losses[-1]:.4f}, Valid Loss: {valid_loss:.4f}')

            # Early stopping
            early_stopping(valid_loss, model)
            if early_stopping.early_stop:
                print("Early stopping")
                break

        print(f"TRAIN ENDED (epochs = {epoch+1})")
        ##########################################################################
        
        temp_model = ModelWithTemperature(model)
        
        temp_model.set_temperature(valid_loader)
        
        temp_model.train_temprature(valid_loader)    
       
        ##########################################################################
        
        test_losses = []
            
        temp_model.eval()

        with torch.no_grad():

            total = 0
            
            test_loss = 0

            test_labels = []

            test_probs  = []

            for x, y, l in test_loader:

                x = x.to(device)
                y = y.to(device)
                l = l.to(device)

                y = y.long()


                ###
                seq_lengths, perm_idx = l.sort(0, descending=True)

                x = x[perm_idx]
                y = y[perm_idx]
                ###


                output = temp_model(x, seq_lengths)
                
                entropy_loss = criterion(output, y)
                
                
                l1_loss_ = 0
                
                for param in model.lstm.parameters():
            
                    l1_loss_ += l1_crit(param, target=torch.zeros_like(param))
            
                factor = 0.0006
        
                l1_loss = factor * l1_loss_

                                
                loss = entropy_loss + l1_loss 
                
                test_loss += loss.item()
                

                outputs = nn.Softmax()(output)

                prediction = outputs.detach().cpu().numpy()

                prediction = prediction[:,1]

                test_labels.append(y.detach().cpu().numpy())

                test_probs.append(prediction)
                
                
            test_losses.append(test_loss/len(test_loader))
        
            #print("Test loss: {:0.4f} ".format(test_loss/len(test_loader)))
        
            #print("TEST ENDED")
            
            
            test_labels = [l for labels in test_labels for l in labels]
        
            test_probs  = [p for probs  in test_probs  for p in probs]
                        
        ##########################################################################
            

        pred_probs_eth0 = np.array(test_probs)[ethnicity_test == 0]
        true_labels_eth0 = np.array(test_labels)[ethnicity_test == 0]

        pred_probs_eth1 = np.array(test_probs)[ethnicity_test == 1]
        true_labels_eth1 = np.array(test_labels)[ethnicity_test == 1]

        # For overall (without ethnicity specification)
        cal_label_ls.append(test_labels)
        cal_prob_ls.append(test_probs)

        TN, FP, FN, TP = confusion_matrix(test_labels, np.array(test_probs).round()).ravel()

        ACC = (TP+TN)/(TP+FP+FN+TN)
        PRC = (TP)/(TP+FP)
        REC = (TP)/(TP+FN)
        F1M = (2*PRC*REC)/(PRC+REC)
        PPV = TP/(TP+FP)
        NPV = TN/(TN+FN)
        SEN = (TP)/(TP+FN)
        SPE = (TN)/(TN+FP)
        MCC = matthews_corrcoef(test_labels, np.array(test_probs).round())
        AUC = roc_auc_score(test_labels, test_probs)
        AP = average_precision_score(test_labels, test_probs)

        Overall_ACC_ls.append(ACC)
        Overall_PRC_ls.append(PRC)
        Overall_REC_ls.append(REC)
        Overall_F1M_ls.append(F1M)
        Overall_PPV_ls.append(PPV)
        Overall_NPV_ls.append(NPV)
        Overall_SEN_ls.append(SEN)
        Overall_SPE_ls.append(SPE)
        Overall_MCC_ls.append(MCC)
        Overall_AUC_ls.append(AUC)
        Overall_AP_ls.append(AP)

        fpr_lstm, tpr_lstm, thresholds_lstm = roc_curve(test_labels, test_probs)
        tprs_ls.append(interp(mean_fpr_ls, fpr_lstm, tpr_lstm))
        tprs_ls[-1][0] = 0.0
        roc_auc = auc(fpr_lstm, tpr_lstm)
        aucs_ls.append(roc_auc)

        no_skill = len([lab for lab in test_labels if lab == 1]) / len(test_labels)
        No_Skill_ls.append(no_skill)

        precision_lstm, recall_lstm, threshold_lstm = precision_recall_curve(test_labels, test_probs)
        prs_ls.append(interp(mean_recall_ls, precision_lstm, recall_lstm))
        pr_ap = auc(recall_lstm, precision_lstm)
        ap_ls.append(pr_ap)


        # For eth1
        cal_label_ls_eth1.append(true_labels_eth1)
        cal_prob_ls_eth1.append(pred_probs_eth1)

        TN_eth1, FP_eth1, FN_eth1, TP_eth1 = confusion_matrix(true_labels_eth1, pred_probs_eth1.round()).ravel()

        ACC_eth1 = (TP_eth1+TN_eth1)/(TP_eth1+FP_eth1+FN_eth1+TN_eth1)
        PRC_eth1 = (TP_eth1)/(TP_eth1+FP_eth1)
        REC_eth1 = (TP_eth1)/(TP_eth1+FN_eth1)
        F1M_eth1 = (2*PRC_eth1*REC_eth1)/(PRC_eth1+REC_eth1)
        PPV_eth1 = TP_eth1/(TP_eth1+FP_eth1)
        NPV_eth1 = TN_eth1/(TN_eth1+FN_eth1)
        SEN_eth1 = (TP_eth1)/(TP_eth1+FN_eth1)
        SPE_eth1 = (TN_eth1)/(TN_eth1+FP_eth1)
        MCC_eth1 = matthews_corrcoef(true_labels_eth1, pred_probs_eth1.round())
        AUC_eth1 = roc_auc_score(true_labels_eth1, pred_probs_eth1)
        AP_eth1 = average_precision_score(true_labels_eth1, pred_probs_eth1)

        Overall_ACC_ls_eth1.append(ACC_eth1)
        Overall_PRC_ls_eth1.append(PRC_eth1)
        Overall_REC_ls_eth1.append(REC_eth1)
        Overall_F1M_ls_eth1.append(F1M_eth1)
        Overall_PPV_ls_eth1.append(PPV_eth1)
        Overall_NPV_ls_eth1.append(NPV_eth1)
        Overall_SEN_ls_eth1.append(SEN_eth1)
        Overall_SPE_ls_eth1.append(SPE_eth1)
        Overall_MCC_ls_eth1.append(MCC_eth1)
        Overall_AUC_ls_eth1.append(AUC_eth1)
        Overall_AP_ls_eth1.append(AP_eth1)

        fpr_lstm_eth1, tpr_lstm_eth1, thresholds_lstm_eth1 = roc_curve(true_labels_eth1, pred_probs_eth1)
        tprs_ls_eth1.append(interp(mean_fpr_ls_eth1, fpr_lstm_eth1, tpr_lstm_eth1))
        tprs_ls_eth1[-1][0] = 0.0
        roc_auc_eth1 = auc(fpr_lstm_eth1, tpr_lstm_eth1)
        aucs_ls_eth1.append(roc_auc_eth1)

        no_skill_eth1 = len([lab for lab in true_labels_eth1 if lab == 1]) / len(true_labels_eth1)
        No_Skill_ls_eth1.append(no_skill_eth1)

        precision_lstm_eth1, recall_lstm_eth1, threshold_lstm_eth1 = precision_recall_curve(true_labels_eth1, pred_probs_eth1)
        prs_ls_eth1.append(interp(mean_recall_ls_eth1, precision_lstm_eth1, recall_lstm_eth1))
        pr_ap_eth1 = auc(recall_lstm_eth1, precision_lstm_eth1)
        ap_ls_eth1.append(pr_ap_eth1)

        # For eth0
        cal_label_ls_eth0.append(true_labels_eth0)
        cal_prob_ls_eth0.append(pred_probs_eth0)

        TN_eth0, FP_eth0, FN_eth0, TP_eth0 = confusion_matrix(true_labels_eth0, pred_probs_eth0.round()).ravel()

        ACC_eth0 = (TP_eth0+TN_eth0)/(TP_eth0+FP_eth0+FN_eth0+TN_eth0)
        PRC_eth0 = (TP_eth0)/(TP_eth0+FP_eth0)
        REC_eth0 = (TP_eth0)/(TP_eth0+FN_eth0)
        F1M_eth0 = (2*PRC_eth0*REC_eth0)/(PRC_eth0+REC_eth0)
        PPV_eth0 = TP_eth0/(TP_eth0+FP_eth0)
        NPV_eth0 = TN_eth0/(TN_eth0+FN_eth0)
        SEN_eth0 = (TP_eth0)/(TP_eth0+FN_eth0)
        SPE_eth0 = (TN_eth0)/(TN_eth0+FP_eth0)
        MCC_eth0 = matthews_corrcoef(true_labels_eth0, pred_probs_eth0.round())
        AUC_eth0 = roc_auc_score(true_labels_eth0, pred_probs_eth0)
        AP_eth0 = average_precision_score(true_labels_eth0, pred_probs_eth0)

        Overall_ACC_ls_eth0.append(ACC_eth0)
        Overall_PRC_ls_eth0.append(PRC_eth0)
        Overall_REC_ls_eth0.append(REC_eth0)
        Overall_F1M_ls_eth0.append(F1M_eth0)
        Overall_PPV_ls_eth0.append(PPV_eth0)
        Overall_NPV_ls_eth0.append(NPV_eth0)
        Overall_SEN_ls_eth0.append(SEN_eth0)
        Overall_SPE_ls_eth0.append(SPE_eth0)
        Overall_MCC_ls_eth0.append(MCC_eth0)
        Overall_AUC_ls_eth0.append(AUC_eth0)
        Overall_AP_ls_eth0.append(AP_eth0)

        fpr_lstm_eth0, tpr_lstm_eth0, thresholds_lstm_eth0 = roc_curve(true_labels_eth0, pred_probs_eth0)
        tprs_ls_eth0.append(interp(mean_fpr_ls_eth0, fpr_lstm_eth0, tpr_lstm_eth0))
        tprs_ls_eth0[-1][0] = 0.0
        roc_auc_eth0 = auc(fpr_lstm_eth0, tpr_lstm_eth0)
        aucs_ls_eth0.append(roc_auc_eth0)

        no_skill_eth0 = len([lab for lab in true_labels_eth0 if lab == 1]) / len(true_labels_eth0)
        No_Skill_ls_eth0.append(no_skill_eth0)

        precision_lstm_eth0, recall_lstm_eth0, threshold_lstm_eth0 = precision_recall_curve(true_labels_eth0, pred_probs_eth0)
        prs_ls_eth0.append(interp(mean_recall_ls_eth0, precision_lstm_eth0, recall_lstm_eth0))
        pr_ap_eth0 = auc(recall_lstm_eth0, precision_lstm_eth0)
        ap_ls_eth0.append(pr_ap_eth0)
        
        ##########################################################################
            
        IG = IntegratedGradients(temp_model)
        
        temp_model.train()
    
        for x, y, l in test_loader:

            x = x.to(device)
            y = y.to(device)
            l = l.to(device)

            y = y.long()


            ###
            seq_lengths, perm_idx = l.sort(0, descending=True)

            x = x[perm_idx]
            y = y[perm_idx]
            ###


            attribute = IG.attribute(x, additional_forward_args=seq_lengths, target=0)

            Total_attribute.append(attribute.detach().cpu().numpy())
            
        
        Total_features.append(X_data_test)

In [None]:
# Mean of every fold-level metric, reported separately for each ethnicity
# group. The accumulator lists are looked up by name
# (Overall_<METRIC>_ls_eth<GROUP>), like the summary-table cell below does.
mean_report_metrics = ('ACC', 'PRC', 'REC', 'F1M', 'PPV', 'NPV', 'SEN', 'SPE', 'MCC', 'AUC', 'AP')

for eth_group in (0, 1):

    if eth_group == 1:
        # Separator between the two groups ("\n" plus print's own newline).
        print("\n")

    print(f"Ethnicity: {eth_group}")

    for metric_name in mean_report_metrics:
        fold_values = globals()[f'Overall_{metric_name}_ls_eth{eth_group}']
        print(f"Mean {metric_name}: {np.mean(fold_values):.4f}")

In [None]:
# Standard deviation (np.std, i.e. population std) of every fold-level
# metric, reported separately for each ethnicity group. The accumulator lists
# are looked up by name (Overall_<METRIC>_ls_eth<GROUP>).
std_report_metrics = ('ACC', 'PRC', 'REC', 'F1M', 'PPV', 'NPV', 'SEN', 'SPE', 'MCC', 'AUC', 'AP')

for eth_group in (0, 1):

    if eth_group == 1:
        # Separator between the two groups ("\n" plus print's own newline).
        print("\n")

    print(f"Ethnicity: {eth_group}")

    for metric_name in std_report_metrics:
        fold_values = globals()[f'Overall_{metric_name}_ls_eth{eth_group}']
        print(f"Std {metric_name}: {np.std(fold_values):.4f}")

In [None]:
# Summary table: mean ± std of each metric over all folds, for the full test
# set and for each ethnicity group. NaN entries are ignored (nanmean/nanstd).
metrics = ['ACC', 'PRC', 'REC', 'F1M', 'PPV', 'NPV', 'SEN', 'SPE', 'MCC', 'AUC', 'AP']

print("Results for original data:\n")

table_header = (
    f"{'Metric':<7}"
    f"{'Overall (Mean ± Std)':<30}"
    f"{'Ethnicity 0 (Mean ± Std)':<30}"
    f"{'Ethnicity 1 (Mean ± Std)':<30}"
)
print(table_header)

print('-' * 96)


for metric in metrics:
    # Fetch the three accumulator lists for this metric by variable name.
    eth0_scores = globals()[f'Overall_{metric}_ls_eth0']
    eth1_scores = globals()[f'Overall_{metric}_ls_eth1']
    full_scores = globals()[f'Overall_{metric}_ls']

    eth0_mean, eth0_std = np.nanmean(eth0_scores), np.nanstd(eth0_scores)
    eth1_mean, eth1_std = np.nanmean(eth1_scores), np.nanstd(eth1_scores)
    full_mean, full_std = np.nanmean(full_scores), np.nanstd(full_scores)

    # Column order: overall, ethnicity 0, ethnicity 1.
    table_row = (
        f"{metric:<7}"
        f"{full_mean:.4f} ± {full_std:<21.4f}"
        f"{eth0_mean:.4f} ± {eth0_std:<21.4f}"
        f"{eth1_mean:.4f} ± {eth1_std:.4f}"
    )
    print(table_row)


In [None]:
import numpy as np
from scipy import stats

# Summary table: mean of each metric over all folds together with a 95%
# interval computed as mean ± 1.96 * SEM (normal approximation). NaN entries
# are left out of both the mean (nanmean) and the SEM (nan_policy='omit').
metrics = ['ACC', 'PRC', 'REC', 'F1M', 'PPV', 'NPV', 'SEN', 'SPE', 'MCC', 'AUC', 'AP']

z_95 = 1.96        # multiplier applied to the SEM for the 95% interval
label_width = 7    # width of the "Metric" column
col_width = 35     # width of every value column, shared by header and rows

print("Results for original data:\n")

print(f"{'Metric':<{label_width}}{'Full data (Mean [95% CI])':<{col_width}}{'Ethnicity 0 (Mean [95% CI])':<{col_width}}{'Ethnicity 1 (Mean [95% CI])':<{col_width}}")

# The rule spans the whole table: label column plus three value columns.
print('-' * (label_width + 3 * col_width))

for metric in metrics:
    # Accumulator lists are looked up by variable name.
    eth0_data = globals()[f'Overall_{metric}_ls_eth0']
    eth1_data = globals()[f'Overall_{metric}_ls_eth1']
    full_data = globals()[f'Overall_{metric}_ls']

    eth0_mean = np.nanmean(eth0_data)
    eth1_mean = np.nanmean(eth1_data)
    full_mean = np.nanmean(full_data)

    eth0_sem = stats.sem(eth0_data, nan_policy='omit')
    eth1_sem = stats.sem(eth1_data, nan_policy='omit')
    full_sem = stats.sem(full_data, nan_policy='omit')

    eth0_ci = eth0_mean - z_95*eth0_sem, eth0_mean + z_95*eth0_sem
    eth1_ci = eth1_mean - z_95*eth1_sem, eth1_mean + z_95*eth1_sem
    full_ci = full_mean - z_95*full_sem, full_mean + z_95*full_sem

    # Render each cell first, then pad it to the header's column width so the
    # values line up under their headings (a fixed 5-space gap did not).
    full_cell = f"{full_mean:.4f} [{full_ci[0]:.4f}, {full_ci[1]:.4f}]"
    eth0_cell = f"{eth0_mean:.4f} [{eth0_ci[0]:.4f}, {eth0_ci[1]:.4f}]"
    eth1_cell = f"{eth1_mean:.4f} [{eth1_ci[0]:.4f}, {eth1_ci[1]:.4f}]"

    print(f"{metric:<{label_width}}{full_cell:<{col_width}}{eth0_cell:<{col_width}}{eth1_cell}")

In [None]:
# Mean ROC curve per ethnicity group (one panel each), with a band of
# +/- one standard deviation of the interpolated TPR curves across folds.
fig, (ax_eth0, ax_eth1) = plt.subplots(1, 2, figsize=(12, 6))

# ---- Ethnicity 0: statistics ----
mean_tpr_ls_eth0 = np.mean(tprs_ls_eth0, axis=0)
mean_tpr_ls_eth0[-1] = 1.0  # pin the last point of the mean curve to TPR = 1
mean_auc_ls_eth0 = auc(mean_fpr_ls_eth0, mean_tpr_ls_eth0)
std_auc_ls_eth0 = np.std(aucs_ls_eth0)

std_tpr_ls_eth0 = np.std(tprs_ls_eth0, axis=0)
tprs_upper_ls_eth0 = np.minimum(mean_tpr_ls_eth0 + std_tpr_ls_eth0, 1)  # clip band to [0, 1]
tprs_lower_ls_eth0 = np.maximum(mean_tpr_ls_eth0 - std_tpr_ls_eth0, 0)

roc_label_eth0 = r'LSTM Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (mean_auc_ls_eth0, std_auc_ls_eth0)

# ---- Ethnicity 0: drawing (chance line, mean curve, std band) ----
ax_eth0.plot([0, 1], [0, 1], linestyle='--', lw=2, color='green', label='Random', alpha=.8)
ax_eth0.plot(mean_fpr_ls_eth0, mean_tpr_ls_eth0, color='r', label=roc_label_eth0, lw=2, alpha=.8)
ax_eth0.fill_between(mean_fpr_ls_eth0, tprs_lower_ls_eth0, tprs_upper_ls_eth0, color='grey', alpha=.2)

ax_eth0.set_xlim([-0.05, 1.05])
ax_eth0.set_ylim([-0.05, 1.05])
ax_eth0.set_xlabel('False Positive Rate')
ax_eth0.set_ylabel('True Positive Rate')
ax_eth0.set_title('ROC Curve - Ethnicity 0')
ax_eth0.legend(loc="lower right")

# ---- Ethnicity 1: statistics ----
mean_tpr_ls_eth1 = np.mean(tprs_ls_eth1, axis=0)
mean_tpr_ls_eth1[-1] = 1.0  # pin the last point of the mean curve to TPR = 1
mean_auc_ls_eth1 = auc(mean_fpr_ls_eth1, mean_tpr_ls_eth1)
std_auc_ls_eth1 = np.std(aucs_ls_eth1)

std_tpr_ls_eth1 = np.std(tprs_ls_eth1, axis=0)
tprs_upper_ls_eth1 = np.minimum(mean_tpr_ls_eth1 + std_tpr_ls_eth1, 1)  # clip band to [0, 1]
tprs_lower_ls_eth1 = np.maximum(mean_tpr_ls_eth1 - std_tpr_ls_eth1, 0)

roc_label_eth1 = r'LSTM Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (mean_auc_ls_eth1, std_auc_ls_eth1)

# ---- Ethnicity 1: drawing (chance line, mean curve, std band) ----
ax_eth1.plot([0, 1], [0, 1], linestyle='--', lw=2, color='green', label='Random', alpha=.8)
ax_eth1.plot(mean_fpr_ls_eth1, mean_tpr_ls_eth1, color='b', label=roc_label_eth1, lw=2, alpha=.8)
ax_eth1.fill_between(mean_fpr_ls_eth1, tprs_lower_ls_eth1, tprs_upper_ls_eth1, color='grey', alpha=.2)

ax_eth1.set_xlim([-0.05, 1.05])
ax_eth1.set_ylim([-0.05, 1.05])
ax_eth1.set_xlabel('False Positive Rate')
ax_eth1.set_ylabel('True Positive Rate')
ax_eth1.set_title('ROC Curve - Ethnicity 1')
ax_eth1.legend(loc="lower right")

fig.tight_layout()
plt.show()

In [None]:
# Mean precision-recall curve per ethnicity group (one panel each), with a
# band of +/- one standard deviation of the interpolated curves across folds.
# The dashed horizontal line is the mean of the per-fold no-skill values.
fig, (ax_eth0, ax_eth1) = plt.subplots(1, 2, figsize=(12, 6))

# ---- Ethnicity 0: statistics ----
no_skill_level_eth0 = np.mean(No_Skill_ls_eth0)

mean_prs_ls_eth0 = np.mean(prs_ls_eth0, axis=0)
mean_ap_ls_eth0 = auc(mean_recall_ls_eth0, mean_prs_ls_eth0)
std_ap_ls_eth0 = np.std(ap_ls_eth0)

std_prs_ls_eth0 = np.std(prs_ls_eth0, axis=0)
prs_upper_ls_eth0 = np.minimum(mean_prs_ls_eth0 + std_prs_ls_eth0, 1)  # clip band to [0, 1]
prs_lower_ls_eth0 = np.maximum(mean_prs_ls_eth0 - std_prs_ls_eth0, 0)

pr_label_eth0 = r'Mean (AP = %0.3f $\pm$ %0.2f)' % (mean_ap_ls_eth0, std_ap_ls_eth0)

# ---- Ethnicity 0: drawing (no-skill line, mean curve, std band) ----
ax_eth0.plot([0, 1], [no_skill_level_eth0, no_skill_level_eth0], linestyle='--', lw=2, color='green', label='Random', alpha=.8)
ax_eth0.plot(mean_recall_ls_eth0, mean_prs_ls_eth0, color='r', label=pr_label_eth0, lw=2, alpha=.8)
ax_eth0.fill_between(mean_recall_ls_eth0, prs_lower_ls_eth0, prs_upper_ls_eth0, color='grey', alpha=.2)

ax_eth0.set_xlim([-0.05, 1.05])
ax_eth0.set_ylim([-0.05, 1.05])
ax_eth0.set_xlabel('Recall')
ax_eth0.set_ylabel('Precision')
ax_eth0.set_title('Precision-Recall Curve - Ethnicity 0')
ax_eth0.legend(loc='upper right')

# ---- Ethnicity 1: statistics ----
no_skill_level_eth1 = np.mean(No_Skill_ls_eth1)

mean_prs_ls_eth1 = np.mean(prs_ls_eth1, axis=0)
mean_ap_ls_eth1 = auc(mean_recall_ls_eth1, mean_prs_ls_eth1)
std_ap_ls_eth1 = np.std(ap_ls_eth1)

std_prs_ls_eth1 = np.std(prs_ls_eth1, axis=0)
prs_upper_ls_eth1 = np.minimum(mean_prs_ls_eth1 + std_prs_ls_eth1, 1)  # clip band to [0, 1]
prs_lower_ls_eth1 = np.maximum(mean_prs_ls_eth1 - std_prs_ls_eth1, 0)

pr_label_eth1 = r'Mean (AP = %0.3f $\pm$ %0.2f)' % (mean_ap_ls_eth1, std_ap_ls_eth1)

# ---- Ethnicity 1: drawing (no-skill line, mean curve, std band) ----
ax_eth1.plot([0, 1], [no_skill_level_eth1, no_skill_level_eth1], linestyle='--', lw=2, color='green', label='Random', alpha=.8)
ax_eth1.plot(mean_recall_ls_eth1, mean_prs_ls_eth1, color='b', label=pr_label_eth1, lw=2, alpha=.8)
ax_eth1.fill_between(mean_recall_ls_eth1, prs_lower_ls_eth1, prs_upper_ls_eth1, color='grey', alpha=.2)

ax_eth1.set_xlim([-0.05, 1.05])
ax_eth1.set_ylim([-0.05, 1.05])
ax_eth1.set_xlabel('Recall')
ax_eth1.set_ylabel('Precision')
ax_eth1.set_title('Precision-Recall Curve - Ethnicity 1')
ax_eth1.legend(loc='upper right')

fig.tight_layout()
plt.show()

In [None]:
# Snapshot of every accumulator from the run above, keyed by variable name.
# Values are the objects themselves (no copies). Key order is: each base name
# with the suffixes _eth1, _eth0 and "" (overall), followed by the two
# attribution entries.
_results_base_names = [
    'Overall_ACC_ls', 'Overall_PRC_ls', 'Overall_REC_ls', 'Overall_F1M_ls',
    'Overall_PPV_ls', 'Overall_NPV_ls', 'Overall_SEN_ls', 'Overall_SPE_ls',
    'Overall_MCC_ls', 'Overall_AUC_ls', 'Overall_AP_ls',
    'No_Skill_ls',
    'tprs_ls', 'aucs_ls', 'mean_fpr_ls',
    'prs_ls', 'ap_ls', 'mean_recall_ls',
    'cal_prob_ls', 'cal_label_ls',
]

_results_keys = [
    base_name + group_suffix
    for base_name in _results_base_names
    for group_suffix in ('_eth1', '_eth0', '')
]
_results_keys += ['Total_attribute', 'Total_features']

results_real = {key: globals()[key] for key in _results_keys}

## LSTM WITH BALANCED DATA

In [None]:
# Re-seed the NumPy and PyTorch random number generators with 123 before the
# balanced-data experiment.
np.random.seed(123)
torch.manual_seed(123)

# cuDNN flags: disable benchmark mode and request deterministic behaviour.
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

In [None]:
# Fresh accumulators for the balanced-data run. Every name is bound to its
# own new object; each line covers the ethnicity-1, ethnicity-0 and overall
# variants of one quantity.

# Per-fold scalar metrics.
Overall_ACC_ls_eth1, Overall_ACC_ls_eth0, Overall_ACC_ls = [], [], []
Overall_PRC_ls_eth1, Overall_PRC_ls_eth0, Overall_PRC_ls = [], [], []
Overall_REC_ls_eth1, Overall_REC_ls_eth0, Overall_REC_ls = [], [], []
Overall_F1M_ls_eth1, Overall_F1M_ls_eth0, Overall_F1M_ls = [], [], []
Overall_PPV_ls_eth1, Overall_PPV_ls_eth0, Overall_PPV_ls = [], [], []
Overall_NPV_ls_eth1, Overall_NPV_ls_eth0, Overall_NPV_ls = [], [], []
Overall_SEN_ls_eth1, Overall_SEN_ls_eth0, Overall_SEN_ls = [], [], []
Overall_SPE_ls_eth1, Overall_SPE_ls_eth0, Overall_SPE_ls = [], [], []
Overall_MCC_ls_eth1, Overall_MCC_ls_eth0, Overall_MCC_ls = [], [], []
Overall_AUC_ls_eth1, Overall_AUC_ls_eth0, Overall_AUC_ls = [], [], []
Overall_AP_ls_eth1, Overall_AP_ls_eth0, Overall_AP_ls = [], [], []

# Per-fold no-skill baselines for the precision-recall plots.
No_Skill_ls_eth1, No_Skill_ls_eth0, No_Skill_ls = [], [], []

# ROC accumulators and the 100-point grid on [0, 1] they are interpolated on.
tprs_ls_eth1, tprs_ls_eth0, tprs_ls = [], [], []
aucs_ls_eth1, aucs_ls_eth0, aucs_ls = [], [], []
mean_fpr_ls_eth1, mean_fpr_ls_eth0, mean_fpr_ls = (np.linspace(0, 1, 100) for _ in range(3))

# Precision-recall accumulators and their 100-point grid on [0, 1].
prs_ls_eth1, prs_ls_eth0, prs_ls = [], [], []
ap_ls_eth1, ap_ls_eth0, ap_ls = [], [], []
mean_recall_ls_eth1, mean_recall_ls_eth0, mean_recall_ls = (np.linspace(0, 1, 100) for _ in range(3))

# Predicted probabilities and true labels kept per fold.
cal_prob_ls_eth1, cal_prob_ls_eth0, cal_prob_ls = [], [], []
cal_label_ls_eth1, cal_label_ls_eth0, cal_label_ls = [], [], []

# Attribution outputs and the matching test features.
Total_attribute = []

Total_features = []

In [None]:
warnings.filterwarnings('ignore')

# Train with balanced data

for i in range(10):

    skf = StratifiedKFold(n_splits=3, random_state=i*5, shuffle=True)

    for train_index, test_index in skf.split(indexing_real.patient_id, indexing_real.stratify_col):


        train = indexing_real.iloc[train_index]

        test  = indexing_real.iloc[test_index]
        
        # Add fake data indices
        # Count unique IDs for each ethnicity in the training set
        train_eth0_count = train[train['ethnicity'] == 0]['patient_id'].nunique()
        train_eth1_count = train[train['ethnicity'] == 1]['patient_id'].nunique()
        
        # Determine how many fake samples to add
        n_fake_to_add = train_eth0_count - train_eth1_count
           
        # If we have more fake samples than needed, randomly select subset
        if len(indexing_fake) > n_fake_to_add:
            fake_to_add = indexing_fake.sample(n=n_fake_to_add, random_state=42)
        else:
            fake_to_add = indexing_fake
        
        # Concatenate fake samples to the training set
        train = pd.concat([train, fake_to_add], ignore_index=True)


        train, valid = train_test_split(train, stratify= train.label, test_size= 0.20, random_state= 123)


        lactate_train_id = train.patient_id.unique()
        
        lactate_valid_id = valid.patient_id.unique()

        lactate_test_id  = test.patient_id.unique()


        train_df = merged_data[merged_data['patient_id'].isin(lactate_train_id)]
        
        valid_df = merged_data[merged_data['patient_id'].isin(lactate_valid_id)]

        test_df  = merged_data[merged_data['patient_id'].isin(lactate_test_id)]

        print("Ethnicity counts after balancing:")
        print(train_df.groupby('patient_id')['ethnicity'].first().value_counts())

        train_df_normilized, valid_df_normilized, test_df_normalized = Normalizing(train_df, valid_df, test_df)


        X_train, y_train, ethnicity_train = Data_Prepare(train_df_normilized)
        
        X_valid, y_valid, ethnicity_valid = Data_Prepare(valid_df_normilized)

        X_test , y_test, ethnicity_test  = Data_Prepare(test_df_normalized)


        X_data_train = np.array(X_train)
        
        X_data_valid = np.array(X_valid)
        
        X_data_test  = np.array(X_test)
        
        
        y_train = np.array(y_train)
        
        y_valid = np.array(y_valid)
        
        y_test  = np.array(y_test)


        ethnicity_test = np.array(ethnicity_test)


        X_len_train = np.full(len(y_train), 14)
        X_len_valid = np.full(len(y_valid), 14)
        X_len_test = np.full(len(y_test), 14)


        train_data = LactateData(X_data_train, y_train, X_len_train)
        valid_data = LactateData(X_data_valid, y_valid, X_len_valid)
        test_data  = LactateData(X_data_test , y_test , X_len_test)
        
        
        train_class_sample_count = torch.tensor([(torch.tensor(y_train) == t).sum() for t in torch.unique(torch.tensor(y_train), sorted=True)])
        train_weight  = 1 / train_class_sample_count.float()
        train_samples_weight = torch.tensor([train_weight[i] for i in torch.tensor(y_train).long()])
        train_sampler = WeightedRandomSampler(train_samples_weight, len(train_samples_weight))

        
        train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=False, sampler=train_sampler, **kwargs)
        valid_loader = DataLoader(valid_data, batch_size=batch_size, shuffle=False, **kwargs)
        test_loader  = DataLoader(test_data , batch_size=batch_size, shuffle=False, **kwargs)
        
        
        model = build_model(input_size, hidden_size, number_layers, number_classes)
        
        model.apply(weight_init)
        
        
        class_weight = torch.Tensor([0.65, 0.35])
        class_weight = class_weight.to(device)
        criterion    = nn.CrossEntropyLoss(weight= class_weight)
        
        l1_crit   = nn.L1Loss(size_average=False)
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
        
        ##########################################################################
        
        # Initialize early stopping
        early_stopping = EarlyStopping(patience=10, verbose=False)

        train_losses = []
        valid_losses = []  # Added to track validation losses

        for epoch in range(num_epochs):
            train_loss = 0
            model.train()

            for x, y, l in train_loader:
                x = x.to(device)
                y = y.to(device)
                l = l.to(device)
                y = y.long()

                seq_lengths, perm_idx = l.sort(0, descending=True)
                x = x[perm_idx]
                y = y[perm_idx]

                outputs = model(x, seq_lengths)
                entropy_loss = criterion(outputs, y)

                l1_loss_ = 0
                for param in model.lstm.parameters():
                    l1_loss_ += l1_crit(param, target=torch.zeros_like(param))
                factor_1 = 0.0006
                l1_loss = factor_1 * l1_loss_

                loss = entropy_loss + l1_loss 

                optimizer.zero_grad()
                loss.backward()
                optimizer.step()   

                train_loss += loss.item()   
            
            train_losses.append(train_loss/len(train_loader))

            # Validation step
            model.eval()
            valid_loss = 0
            with torch.no_grad():
                for x, y, l in valid_loader:
                    x = x.to(device)
                    y = y.to(device)
                    l = l.to(device)
                    y = y.long()

                    seq_lengths, perm_idx = l.sort(0, descending=True)
                    x = x[perm_idx]
                    y = y[perm_idx]

                    outputs = model(x, seq_lengths)
                    loss = criterion(outputs, y)  # Using only entropy loss for validation
                    valid_loss += loss.item()

            valid_loss = valid_loss / len(valid_loader)
            valid_losses.append(valid_loss)

            #print(f'Epoch {epoch+1}/{num_epochs}, Train Loss: {train_losses[-1]:.4f}, Valid Loss: {valid_loss:.4f}')

            # Early stopping
            early_stopping(valid_loss, model)
            if early_stopping.early_stop:
                print("Early stopping")
                break

        print(f"TRAIN ENDED (epochs = {epoch+1})")
            
        ##########################################################################
        
        temp_model = ModelWithTemperature(model)
        
        temp_model.set_temperature(valid_loader)
        
        temp_model.train_temprature(valid_loader)    
       
        ##########################################################################
        
        test_losses = []
            
        temp_model.eval()

        with torch.no_grad():

            total = 0
            
            test_loss = 0

            test_labels = []

            test_probs  = []

            for x, y, l in test_loader:

                x = x.to(device)
                y = y.to(device)
                l = l.to(device)

                y = y.long()


                ###
                seq_lengths, perm_idx = l.sort(0, descending=True)

                x = x[perm_idx]
                y = y[perm_idx]
                ###


                output = temp_model(x, seq_lengths)
                
                entropy_loss = criterion(output, y)
                
                
                l1_loss_ = 0
                
                for param in model.lstm.parameters():
            
                    l1_loss_ += l1_crit(param, target=torch.zeros_like(param))
            
                factor = 0.0006
        
                l1_loss = factor * l1_loss_

                                
                loss = entropy_loss + l1_loss 
                
                test_loss += loss.item()
                

                outputs = nn.Softmax()(output)

                prediction = outputs.detach().cpu().numpy()

                prediction = prediction[:,1]

                test_labels.append(y.detach().cpu().numpy())

                test_probs.append(prediction)
                
                
            test_losses.append(test_loss/len(test_loader))
        
            #print("Test loss: {:0.4f} ".format(test_loss/len(test_loader)))
        
            #print("TEST ENDED")
            
            
            test_labels = [l for labels in test_labels for l in labels]
        
            test_probs  = [p for probs  in test_probs  for p in probs]
                        
        ##########################################################################
            

        pred_probs_eth0 = np.array(test_probs)[ethnicity_test == 0]
        true_labels_eth0 = np.array(test_labels)[ethnicity_test == 0]

        pred_probs_eth1 = np.array(test_probs)[ethnicity_test == 1]
        true_labels_eth1 = np.array(test_labels)[ethnicity_test == 1]

                # For overall (without ethnicity specification)
        cal_label_ls.append(test_labels)
        cal_prob_ls.append(test_probs)

        TN, FP, FN, TP = confusion_matrix(test_labels, np.array(test_probs).round()).ravel()

        ACC = (TP+TN)/(TP+FP+FN+TN)
        PRC = (TP)/(TP+FP)
        REC = (TP)/(TP+FN)
        F1M = (2*PRC*REC)/(PRC+REC)
        PPV = TP/(TP+FP)
        NPV = TN/(TN+FN)
        SEN = (TP)/(TP+FN)
        SPE = (TN)/(TN+FP)
        MCC = matthews_corrcoef(test_labels, np.array(test_probs).round())
        AUC = roc_auc_score(test_labels, test_probs)
        AP = average_precision_score(test_labels, test_probs)

        Overall_ACC_ls.append(ACC)
        Overall_PRC_ls.append(PRC)
        Overall_REC_ls.append(REC)
        Overall_F1M_ls.append(F1M)
        Overall_PPV_ls.append(PPV)
        Overall_NPV_ls.append(NPV)
        Overall_SEN_ls.append(SEN)
        Overall_SPE_ls.append(SPE)
        Overall_MCC_ls.append(MCC)
        Overall_AUC_ls.append(AUC)
        Overall_AP_ls.append(AP)

        fpr_lstm, tpr_lstm, thresholds_lstm = roc_curve(test_labels, test_probs)
        tprs_ls.append(interp(mean_fpr_ls, fpr_lstm, tpr_lstm))
        tprs_ls[-1][0] = 0.0
        roc_auc = auc(fpr_lstm, tpr_lstm)
        aucs_ls.append(roc_auc)

        no_skill = len([lab for lab in test_labels if lab == 1]) / len(test_labels)
        No_Skill_ls.append(no_skill)

        precision_lstm, recall_lstm, threshold_lstm = precision_recall_curve(test_labels, test_probs)
        prs_ls.append(interp(mean_recall_ls, precision_lstm, recall_lstm))
        pr_ap = auc(recall_lstm, precision_lstm)
        ap_ls.append(pr_ap)


        # For eth1
        cal_label_ls_eth1.append(true_labels_eth1)
        cal_prob_ls_eth1.append(pred_probs_eth1)

        TN_eth1, FP_eth1, FN_eth1, TP_eth1 = confusion_matrix(true_labels_eth1, pred_probs_eth1.round()).ravel()

        ACC_eth1 = (TP_eth1+TN_eth1)/(TP_eth1+FP_eth1+FN_eth1+TN_eth1)
        PRC_eth1 = (TP_eth1)/(TP_eth1+FP_eth1)
        REC_eth1 = (TP_eth1)/(TP_eth1+FN_eth1)
        F1M_eth1 = (2*PRC_eth1*REC_eth1)/(PRC_eth1+REC_eth1)
        PPV_eth1 = TP_eth1/(TP_eth1+FP_eth1)
        NPV_eth1 = TN_eth1/(TN_eth1+FN_eth1)
        SEN_eth1 = (TP_eth1)/(TP_eth1+FN_eth1)
        SPE_eth1 = (TN_eth1)/(TN_eth1+FP_eth1)
        MCC_eth1 = matthews_corrcoef(true_labels_eth1, pred_probs_eth1.round())
        AUC_eth1 = roc_auc_score(true_labels_eth1, pred_probs_eth1)
        AP_eth1 = average_precision_score(true_labels_eth1, pred_probs_eth1)

        Overall_ACC_ls_eth1.append(ACC_eth1)
        Overall_PRC_ls_eth1.append(PRC_eth1)
        Overall_REC_ls_eth1.append(REC_eth1)
        Overall_F1M_ls_eth1.append(F1M_eth1)
        Overall_PPV_ls_eth1.append(PPV_eth1)
        Overall_NPV_ls_eth1.append(NPV_eth1)
        Overall_SEN_ls_eth1.append(SEN_eth1)
        Overall_SPE_ls_eth1.append(SPE_eth1)
        Overall_MCC_ls_eth1.append(MCC_eth1)
        Overall_AUC_ls_eth1.append(AUC_eth1)
        Overall_AP_ls_eth1.append(AP_eth1)

        fpr_lstm_eth1, tpr_lstm_eth1, thresholds_lstm_eth1 = roc_curve(true_labels_eth1, pred_probs_eth1)
        tprs_ls_eth1.append(interp(mean_fpr_ls_eth1, fpr_lstm_eth1, tpr_lstm_eth1))
        tprs_ls_eth1[-1][0] = 0.0
        roc_auc_eth1 = auc(fpr_lstm_eth1, tpr_lstm_eth1)
        aucs_ls_eth1.append(roc_auc_eth1)

        no_skill_eth1 = len([lab for lab in true_labels_eth1 if lab == 1]) / len(true_labels_eth1)
        No_Skill_ls_eth1.append(no_skill_eth1)

        precision_lstm_eth1, recall_lstm_eth1, threshold_lstm_eth1 = precision_recall_curve(true_labels_eth1, pred_probs_eth1)
        prs_ls_eth1.append(interp(mean_recall_ls_eth1, precision_lstm_eth1, recall_lstm_eth1))
        pr_ap_eth1 = auc(recall_lstm_eth1, precision_lstm_eth1)
        ap_ls_eth1.append(pr_ap_eth1)

        # For eth0
        cal_label_ls_eth0.append(true_labels_eth0)
        cal_prob_ls_eth0.append(pred_probs_eth0)

        TN_eth0, FP_eth0, FN_eth0, TP_eth0 = confusion_matrix(true_labels_eth0, pred_probs_eth0.round()).ravel()

        ACC_eth0 = (TP_eth0+TN_eth0)/(TP_eth0+FP_eth0+FN_eth0+TN_eth0)
        PRC_eth0 = (TP_eth0)/(TP_eth0+FP_eth0)
        REC_eth0 = (TP_eth0)/(TP_eth0+FN_eth0)
        F1M_eth0 = (2*PRC_eth0*REC_eth0)/(PRC_eth0+REC_eth0)
        PPV_eth0 = TP_eth0/(TP_eth0+FP_eth0)
        NPV_eth0 = TN_eth0/(TN_eth0+FN_eth0)
        SEN_eth0 = (TP_eth0)/(TP_eth0+FN_eth0)
        SPE_eth0 = (TN_eth0)/(TN_eth0+FP_eth0)
        MCC_eth0 = matthews_corrcoef(true_labels_eth0, pred_probs_eth0.round())
        AUC_eth0 = roc_auc_score(true_labels_eth0, pred_probs_eth0)
        AP_eth0 = average_precision_score(true_labels_eth0, pred_probs_eth0)

        Overall_ACC_ls_eth0.append(ACC_eth0)
        Overall_PRC_ls_eth0.append(PRC_eth0)
        Overall_REC_ls_eth0.append(REC_eth0)
        Overall_F1M_ls_eth0.append(F1M_eth0)
        Overall_PPV_ls_eth0.append(PPV_eth0)
        Overall_NPV_ls_eth0.append(NPV_eth0)
        Overall_SEN_ls_eth0.append(SEN_eth0)
        Overall_SPE_ls_eth0.append(SPE_eth0)
        Overall_MCC_ls_eth0.append(MCC_eth0)
        Overall_AUC_ls_eth0.append(AUC_eth0)
        Overall_AP_ls_eth0.append(AP_eth0)

        fpr_lstm_eth0, tpr_lstm_eth0, thresholds_lstm_eth0 = roc_curve(true_labels_eth0, pred_probs_eth0)
        tprs_ls_eth0.append(interp(mean_fpr_ls_eth0, fpr_lstm_eth0, tpr_lstm_eth0))
        tprs_ls_eth0[-1][0] = 0.0
        roc_auc_eth0 = auc(fpr_lstm_eth0, tpr_lstm_eth0)
        aucs_ls_eth0.append(roc_auc_eth0)

        no_skill_eth0 = len([lab for lab in true_labels_eth0 if lab == 1]) / len(true_labels_eth0)
        No_Skill_ls_eth0.append(no_skill_eth0)

        precision_lstm_eth0, recall_lstm_eth0, threshold_lstm_eth0 = precision_recall_curve(true_labels_eth0, pred_probs_eth0)
        prs_ls_eth0.append(interp(mean_recall_ls_eth0, precision_lstm_eth0, recall_lstm_eth0))
        pr_ap_eth0 = auc(recall_lstm_eth0, precision_lstm_eth0)
        ap_ls_eth0.append(pr_ap_eth0)
        
        ##########################################################################
            
        IG = IntegratedGradients(temp_model)
        
        temp_model.train()
    
        for x, y, l in test_loader:

            x = x.to(device)
            y = y.to(device)
            l = l.to(device)

            y = y.long()


            ###
            seq_lengths, perm_idx = l.sort(0, descending=True)

            x = x[perm_idx]
            y = y[perm_idx]
            ###


            attribute = IG.attribute(x, additional_forward_args=seq_lengths, target=0)

            Total_attribute.append(attribute.detach().cpu().numpy())
            
        
        Total_features.append(X_data_test)

In [None]:
# Print the mean of every per-ethnicity metric list, one section per group.
# A blank separator (print("\n")) is emitted between the two sections only.
for group_tag, metric_lists in (
    ("0", (
        ("ACC", Overall_ACC_ls_eth0),
        ("PRC", Overall_PRC_ls_eth0),
        ("REC", Overall_REC_ls_eth0),
        ("F1M", Overall_F1M_ls_eth0),
        ("PPV", Overall_PPV_ls_eth0),
        ("NPV", Overall_NPV_ls_eth0),
        ("SEN", Overall_SEN_ls_eth0),
        ("SPE", Overall_SPE_ls_eth0),
        ("MCC", Overall_MCC_ls_eth0),
        ("AUC", Overall_AUC_ls_eth0),
        ("AP", Overall_AP_ls_eth0),
    )),
    ("1", (
        ("ACC", Overall_ACC_ls_eth1),
        ("PRC", Overall_PRC_ls_eth1),
        ("REC", Overall_REC_ls_eth1),
        ("F1M", Overall_F1M_ls_eth1),
        ("PPV", Overall_PPV_ls_eth1),
        ("NPV", Overall_NPV_ls_eth1),
        ("SEN", Overall_SEN_ls_eth1),
        ("SPE", Overall_SPE_ls_eth1),
        ("MCC", Overall_MCC_ls_eth1),
        ("AUC", Overall_AUC_ls_eth1),
        ("AP", Overall_AP_ls_eth1),
    )),
):
    print(f"Ethnicity: {group_tag}")

    for metric_name, fold_values in metric_lists:
        print(f"Mean {metric_name}: {np.mean(fold_values):.4f}")

    if group_tag == "0":
        print("\n")

In [None]:
def _report_std(group_tag, named_lists):
    """Print the np.std of each metric list for one ethnicity group.

    group_tag: text shown after "Ethnicity: " in the section header.
    named_lists: mapping of metric abbreviation -> list of collected values,
        printed in insertion order.
    """
    print(f"Ethnicity: {group_tag}")
    for metric_name, fold_values in named_lists.items():
        print(f"Std {metric_name}: {np.std(fold_values):.4f}")


_report_std("0", {
    "ACC": Overall_ACC_ls_eth0,
    "PRC": Overall_PRC_ls_eth0,
    "REC": Overall_REC_ls_eth0,
    "F1M": Overall_F1M_ls_eth0,
    "PPV": Overall_PPV_ls_eth0,
    "NPV": Overall_NPV_ls_eth0,
    "SEN": Overall_SEN_ls_eth0,
    "SPE": Overall_SPE_ls_eth0,
    "MCC": Overall_MCC_ls_eth0,
    "AUC": Overall_AUC_ls_eth0,
    "AP": Overall_AP_ls_eth0,
})

print("\n")  # visual gap between the two sections

_report_std("1", {
    "ACC": Overall_ACC_ls_eth1,
    "PRC": Overall_PRC_ls_eth1,
    "REC": Overall_REC_ls_eth1,
    "F1M": Overall_F1M_ls_eth1,
    "PPV": Overall_PPV_ls_eth1,
    "NPV": Overall_NPV_ls_eth1,
    "SEN": Overall_SEN_ls_eth1,
    "SPE": Overall_SPE_ls_eth1,
    "MCC": Overall_MCC_ls_eth1,
    "AUC": Overall_AUC_ls_eth1,
    "AP": Overall_AP_ls_eth1,
})

In [None]:
# Summary table: NaN-aware mean and std of each metric for the full test set
# and for each ethnicity group. The metric lists are looked up by name in the
# notebook's global namespace (Overall_<METRIC>_ls[_eth0|_eth1]).
metrics = ['ACC', 'PRC', 'REC', 'F1M', 'PPV', 'NPV', 'SEN', 'SPE', 'MCC', 'AUC', 'AP']

print("Results for balanced data:\n")

header_row = f"{'Metric':<7}{'Overall (Mean ± Std)':<30}{'Ethnicity 0 (Mean ± Std)':<30}{'Ethnicity 1 (Mean ± Std)':<30}"
print(header_row)

print('-' * 96)


for metric in metrics:
    eth0_values = globals()[f'Overall_{metric}_ls_eth0']
    eth0_mean, eth0_std = np.nanmean(eth0_values), np.nanstd(eth0_values)

    eth1_values = globals()[f'Overall_{metric}_ls_eth1']
    eth1_mean, eth1_std = np.nanmean(eth1_values), np.nanstd(eth1_values)

    overall_values = globals()[f'Overall_{metric}_ls']
    full_mean, full_std = np.nanmean(overall_values), np.nanstd(overall_values)

    # Assemble the row cell by cell; the std field is left-padded to 21
    # characters so that "mean ± std" lines up under the 30-wide headers.
    table_row = (
        f"{metric:<7}"
        f"{full_mean:.4f} ± {full_std:<21.4f}"
        f"{eth0_mean:.4f} ± {eth0_std:<21.4f}"
        f"{eth1_mean:.4f} ± {eth1_std:.4f}"
    )
    print(table_row)


In [None]:
import numpy as np
from scipy import stats

# Summary table: NaN-aware mean and 95% confidence interval of each metric for
# the full test set and for each ethnicity group. The metric lists are looked
# up by name in the notebook's global namespace.
metrics = ['ACC', 'PRC', 'REC', 'F1M', 'PPV', 'NPV', 'SEN', 'SPE', 'MCC', 'AUC', 'AP']

# Width of each "mean [low, high]" column. The header, the separator rule and
# every data row derive from this value so the table stays aligned (the rows
# used to be 28 characters wide under 35-wide headers and a 77-character rule).
ci_column_width = 35


def _t_critical(values, confidence=0.95):
    """Return the two-sided Student-t critical value for the non-NaN entries.

    Uses n - 1 degrees of freedom, where n counts only the non-NaN values, to
    match the NaN-omitting mean and standard error computed below. The result
    is NaN when fewer than two valid values are available.
    """
    n_valid = np.count_nonzero(~np.isnan(np.asarray(values, dtype=float)))
    return stats.t.ppf((1 + confidence) / 2., n_valid - 1)


print("Results for balanced data:\n")

print(f"{'Metric':<7}{'Full data (Mean [95% CI])':<{ci_column_width}}{'Ethnicity 0 (Mean [95% CI])':<{ci_column_width}}{'Ethnicity 1 (Mean [95% CI])':<{ci_column_width}}")

print('-' * (7 + 3 * ci_column_width))

for metric in metrics:
    eth0_data = globals()[f'Overall_{metric}_ls_eth0']
    eth1_data = globals()[f'Overall_{metric}_ls_eth1']
    full_data = globals()[f'Overall_{metric}_ls']

    eth0_mean = np.nanmean(eth0_data)
    eth1_mean = np.nanmean(eth1_data)
    full_mean = np.nanmean(full_data)

    eth0_sem = stats.sem(eth0_data, nan_policy='omit')
    eth1_sem = stats.sem(eth1_data, nan_policy='omit')
    full_sem = stats.sem(full_data, nan_policy='omit')

    # Half-width of the interval. A Student-t critical value is used instead
    # of the normal 1.96 because the number of values per list is small; this
    # also agrees with mean_confidence_interval() used in the bias cells.
    eth0_half = _t_critical(eth0_data) * eth0_sem
    eth1_half = _t_critical(eth1_data) * eth1_sem
    full_half = _t_critical(full_data) * full_sem

    eth0_ci = eth0_mean - eth0_half, eth0_mean + eth0_half
    eth1_ci = eth1_mean - eth1_half, eth1_mean + eth1_half
    full_ci = full_mean - full_half, full_mean + full_half

    # Format each cell first, then pad it to the shared column width.
    full_cell = f"{full_mean:.4f} [{full_ci[0]:.4f}, {full_ci[1]:.4f}]"
    eth0_cell = f"{eth0_mean:.4f} [{eth0_ci[0]:.4f}, {eth0_ci[1]:.4f}]"
    eth1_cell = f"{eth1_mean:.4f} [{eth1_ci[0]:.4f}, {eth1_ci[1]:.4f}]"

    print(f"{metric:<7}{full_cell:<{ci_column_width}}{eth0_cell:<{ci_column_width}}{eth1_cell}")

In [None]:
def _draw_mean_roc(panel_index, fpr_grid, fold_tprs, fold_aucs, line_color, panel_title):
    """Draw one mean-ROC panel and return the statistics it computed.

    panel_index: 1-based position in the 1x2 subplot grid.
    fpr_grid: common false-positive-rate grid the TPR curves were interpolated on.
    fold_tprs: list of TPR curves, one per evaluation run, all on fpr_grid.
    fold_aucs: list of per-run ROC AUC values (used only for the +/- in the legend).
    line_color: matplotlib colour of the mean curve.
    panel_title: subplot title.

    Returns (mean_tpr, mean_auc, std_auc, std_tpr, upper_band, lower_band).
    """
    plt.subplot(1, 2, panel_index)

    # Chance diagonal.
    plt.plot([0, 1], [0, 1], linestyle='--', lw=2, color='green', label='Random', alpha=.8)

    # Mean curve across runs; the last point is pinned to TPR = 1.
    mean_curve = np.mean(fold_tprs, axis=0)
    mean_curve[-1] = 1.0
    curve_auc = auc(fpr_grid, mean_curve)
    auc_spread = np.std(fold_aucs)
    plt.plot(fpr_grid, mean_curve, color=line_color,
             label=r'LSTM Mean ROC (AUC = %0.2f $\pm$ %0.2f)' % (curve_auc, auc_spread),
             lw=2, alpha=.8)

    # +/- one standard deviation band, clipped to the [0, 1] range.
    curve_spread = np.std(fold_tprs, axis=0)
    upper_band = np.minimum(mean_curve + curve_spread, 1)
    lower_band = np.maximum(mean_curve - curve_spread, 0)
    plt.fill_between(fpr_grid, lower_band, upper_band, color='grey', alpha=.2)

    plt.xlim([-0.05, 1.05])
    plt.ylim([-0.05, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title(panel_title)
    plt.legend(loc="lower right")

    return mean_curve, curve_auc, auc_spread, curve_spread, upper_band, lower_band


plt.figure(figsize=(12, 6))

# Left panel: Ethnicity 0. The results are unpacked into the same global
# names the notebook has always exposed for this plot.
(mean_tpr_ls_eth0, mean_auc_ls_eth0, std_auc_ls_eth0,
 std_tpr_ls_eth0, tprs_upper_ls_eth0, tprs_lower_ls_eth0) = _draw_mean_roc(
    1, mean_fpr_ls_eth0, tprs_ls_eth0, aucs_ls_eth0, 'r', 'ROC Curve - Ethnicity 0')

# Right panel: Ethnicity 1.
(mean_tpr_ls_eth1, mean_auc_ls_eth1, std_auc_ls_eth1,
 std_tpr_ls_eth1, tprs_upper_ls_eth1, tprs_lower_ls_eth1) = _draw_mean_roc(
    2, mean_fpr_ls_eth1, tprs_ls_eth1, aucs_ls_eth1, 'b', 'ROC Curve - Ethnicity 1')

plt.tight_layout()
plt.show()

In [None]:
def _draw_mean_pr(panel_index, recall_grid, fold_curves, fold_aps, baseline_rates,
                  line_color, panel_title):
    """Draw one mean precision-recall panel and return the statistics it computed.

    panel_index: 1-based position in the 1x2 subplot grid.
    recall_grid: common x-axis grid the per-run curves were interpolated on.
    fold_curves: list of interpolated curves, one per evaluation run.
    fold_aps: list of per-run PR-curve areas (used only for the +/- in the legend).
    baseline_rates: per-run positive-class fractions; their mean is drawn as
        the horizontal "Random" reference line.
    line_color: matplotlib colour of the mean curve.
    panel_title: subplot title.

    Returns (mean_curve, mean_area, std_area, std_curve, upper_band, lower_band).
    """
    plt.subplot(1, 2, panel_index)

    # Horizontal no-skill reference at the average positive rate.
    chance_level = np.mean(baseline_rates)
    plt.plot([0, 1], [chance_level, chance_level], linestyle='--', lw=2, color='green', label='Random', alpha=.8)

    # Mean curve across runs and the area under it.
    mean_curve = np.mean(fold_curves, axis=0)
    curve_area = auc(recall_grid, mean_curve)
    area_spread = np.std(fold_aps)
    plt.plot(recall_grid, mean_curve, color=line_color,
             label=r'Mean (AP = %0.3f $\pm$ %0.2f)' % (curve_area, area_spread),
             lw=2, alpha=.8)

    # +/- one standard deviation band, clipped to the [0, 1] range.
    curve_spread = np.std(fold_curves, axis=0)
    upper_band = np.minimum(mean_curve + curve_spread, 1)
    lower_band = np.maximum(mean_curve - curve_spread, 0)
    plt.fill_between(recall_grid, lower_band, upper_band, color='grey', alpha=.2)

    plt.xlim([-0.05, 1.05])
    plt.ylim([-0.05, 1.05])
    plt.xlabel('Recall')
    plt.ylabel('Precision')
    plt.title(panel_title)
    plt.legend(loc='upper right')

    return mean_curve, curve_area, area_spread, curve_spread, upper_band, lower_band


plt.figure(figsize=(12, 6))

# Left panel: Ethnicity 0. The results are unpacked into the same global
# names the notebook has always exposed for this plot.
(mean_prs_ls_eth0, mean_ap_ls_eth0, std_ap_ls_eth0,
 std_prs_ls_eth0, prs_upper_ls_eth0, prs_lower_ls_eth0) = _draw_mean_pr(
    1, mean_recall_ls_eth0, prs_ls_eth0, ap_ls_eth0, No_Skill_ls_eth0,
    'r', 'Precision-Recall Curve - Ethnicity 0')

# Right panel: Ethnicity 1.
(mean_prs_ls_eth1, mean_ap_ls_eth1, std_ap_ls_eth1,
 std_prs_ls_eth1, prs_upper_ls_eth1, prs_lower_ls_eth1) = _draw_mean_pr(
    2, mean_recall_ls_eth1, prs_ls_eth1, ap_ls_eth1, No_Skill_ls_eth1,
    'b', 'Precision-Recall Curve - Ethnicity 1')

plt.tight_layout()
plt.show()

In [None]:
# Bundle every result list of the balanced run into one dict keyed by the
# variable's own name, so it can be compared against another run's bundle.
# Each stem is stored in three flavours, in this order: ethnicity 1,
# ethnicity 0, and the full test set (no suffix).
_metric_codes = ('ACC', 'PRC', 'REC', 'F1M', 'PPV', 'NPV', 'SEN', 'SPE', 'MCC', 'AUC', 'AP')
_result_stems = [f'Overall_{code}_ls' for code in _metric_codes] + [
    'No_Skill_ls',
    'tprs_ls',
    'aucs_ls',
    'mean_fpr_ls',
    'prs_ls',
    'ap_ls',
    'mean_recall_ls',
    'cal_prob_ls',
    'cal_label_ls',
]

results_balanced = {
    f'{stem}{suffix}': globals()[f'{stem}{suffix}']
    for stem in _result_stems
    for suffix in ('_eth1', '_eth0', '')
}

# Attribution outputs have no per-ethnicity variants.
results_balanced['Total_attribute'] = Total_attribute
results_balanced['Total_features'] = Total_features

# Bias results

In [None]:
def calculate_difference(results, metric):
    """Return the element-wise gap ``eth0 - eth1`` for one metric.

    results: dict holding the lists ``f'{metric}_ls_eth0'`` and
        ``f'{metric}_ls_eth1'``.
    metric: key prefix, e.g. ``'Overall_ACC'``.

    The two lists are paired by position; if their lengths differ, the extra
    entries of the longer one are ignored.
    """
    group0_scores = results[f'{metric}_ls_eth0']
    group1_scores = results[f'{metric}_ls_eth1']

    gaps = []
    for score0, score1 in zip(group0_scores, group1_scores):
        gaps.append(score0 - score1)
    return gaps

In [None]:
metrics_to_compare = ['Overall_ACC', 'Overall_AUC', 'Overall_MCC']

# For each metric, print the mean eth0 - eth1 gap of the unbalanced run
# (results_real) next to that of the balanced run (results_balanced).
for metric in metrics_to_compare:
    real_diff = calculate_difference(results_real, metric)
    balanced_diff = calculate_difference(results_balanced, metric)

    report_lines = [
        f"\nComparing {metric}:",
        f"Unbalanced mean difference: {np.mean(real_diff):.4f}",
        f"Balanced mean difference: {np.mean(balanced_diff):.4f}",
    ]
    print("\n".join(report_lines))
    


In [None]:
import numpy as np
from scipy import stats

def calculate_difference(results, metric):
    """Compute ``eth0 - eth1`` position by position for the given metric.

    Reads ``results[f'{metric}_ls_eth0']`` and ``results[f'{metric}_ls_eth1']``
    and returns a list as long as the shorter of the two.
    """
    paired_scores = zip(results[f'{metric}_ls_eth0'], results[f'{metric}_ls_eth1'])
    return [pair[0] - pair[1] for pair in paired_scores]

def mean_confidence_interval(data, confidence=0.95):
    """Return ``(mean, lower, upper)`` of a Student-t confidence interval.

    data: sequence of numbers.
    confidence: two-sided confidence level (default 0.95).

    The half-width is the standard error (``scipy.stats.sem``) multiplied by
    the t quantile with ``len(data) - 1`` degrees of freedom.
    """
    samples = 1.0 * np.array(data)
    center = np.mean(samples)
    standard_error = stats.sem(samples)
    t_quantile = stats.t.ppf((1 + confidence) / 2., len(samples) - 1)
    margin = standard_error * t_quantile
    return center, center - margin, center + margin

metrics_to_compare = ['Overall_ACC', 'Overall_AUC', 'Overall_MCC']

# For each metric, compare the eth0 - eth1 performance gap of the unbalanced
# run (results_real) with that of the balanced run (results_balanced).
for metric in metrics_to_compare:
    real_diff = calculate_difference(results_real, metric)
    balanced_diff = calculate_difference(results_balanced, metric)

    real_mean, real_ci_low, real_ci_high = mean_confidence_interval(real_diff)
    balanced_mean, balanced_ci_low, balanced_ci_high = mean_confidence_interval(balanced_diff)

    print(f"\nComparing {metric}:")
    print(f"Unbalanced mean difference: {real_mean:.4f} (95% CI: {real_ci_low:.4f} to {real_ci_high:.4f})")
    print(f"Balanced mean difference: {balanced_mean:.4f} (95% CI: {balanced_ci_low:.4f} to {balanced_ci_high:.4f})")

    # "Bias" here is the magnitude of the mean gap between the two groups.
    if abs(balanced_mean) < abs(real_mean):
        print("Balanced method reduced bias.")
    else:
        print("Balanced method did not reduce bias.")

    # Significance of the change in gap. Checking whether the balanced mean
    # falls outside the unbalanced CI treats the balanced mean as an exact
    # constant and ignores its own sampling variability, so it overstates
    # significance. Welch's two-sample t-test accounts for the spread of both
    # sets of differences and does not assume equal variances.
    # NOTE(review): if both runs were evaluated on identical splits, a paired
    # test (stats.ttest_rel) would be the more powerful choice - confirm.
    p_value = stats.ttest_ind(real_diff, balanced_diff, equal_var=False).pvalue
    print(f"Welch's t-test p-value: {p_value:.4f}")

    if p_value < 0.05:
        print("The difference is statistically significant at the 95% confidence level.")
    else:
        print("The difference is not statistically significant at the 95% confidence level.")

In [None]:
def calculate_difference(results, metric, group_size=3):
    """Return the ``eth0 - eth1`` gap after averaging consecutive result groups.

    The per-ethnicity lists are first collapsed by averaging every
    ``group_size`` consecutive entries; the gap is then computed between the
    collapsed lists, position by position.

    results: dict holding the lists ``f'{metric}_ls_eth0'`` and
        ``f'{metric}_ls_eth1'``.
    metric: key prefix, e.g. ``'Overall_ACC'``.
    group_size: number of consecutive entries averaged together (default 3,
        the value this notebook has always used).

    Returns a list with one gap per group. A trailing group with fewer than
    ``group_size`` entries is averaged over the entries it actually has.

    Raises ValueError if ``group_size`` is smaller than 1.
    """
    if group_size < 1:
        raise ValueError("group_size must be a positive integer")

    def average_groups(lst):
        # Divide by the real chunk length: dividing by group_size would
        # under-estimate the mean of a short final chunk.
        chunks = (lst[i:i + group_size] for i in range(0, len(lst), group_size))
        return [sum(chunk) / len(chunk) for chunk in chunks]

    # Average the results for each ethnicity
    avg_eth0 = average_groups(results[f'{metric}_ls_eth0'])
    avg_eth1 = average_groups(results[f'{metric}_ls_eth1'])

    # Calculate the difference between averaged results
    return [e0 - e1 for e0, e1 in zip(avg_eth0, avg_eth1)]

def mean_confidence_interval(data, confidence=0.95):
    """Compute the sample mean and its two-sided Student-t confidence bounds.

    data: sequence of numbers.
    confidence: confidence level of the interval (default 0.95).

    Returns the tuple ``(mean, lower_bound, upper_bound)``.
    """
    values = 1.0 * np.array(data)
    sample_mean = np.mean(values)
    dof = len(values) - 1
    half_width = stats.sem(values) * stats.t.ppf((1 + confidence) / 2., dof)
    lower_bound, upper_bound = sample_mean - half_width, sample_mean + half_width
    return sample_mean, lower_bound, upper_bound

metrics_to_compare = ['Overall_ACC', 'Overall_AUC', 'Overall_MCC']

# For each metric, compare the group-averaged eth0 - eth1 performance gap of
# the unbalanced run (results_real) with that of the balanced run
# (results_balanced).
for metric in metrics_to_compare:
    real_diff = calculate_difference(results_real, metric)
    balanced_diff = calculate_difference(results_balanced, metric)

    real_mean, real_ci_low, real_ci_high = mean_confidence_interval(real_diff)
    balanced_mean, balanced_ci_low, balanced_ci_high = mean_confidence_interval(balanced_diff)

    print(f"\nComparing {metric}:")
    print(f"Unbalanced mean difference: {real_mean:.4f} (95% CI: {real_ci_low:.4f} to {real_ci_high:.4f})")
    print(f"Balanced mean difference: {balanced_mean:.4f} (95% CI: {balanced_ci_low:.4f} to {balanced_ci_high:.4f})")

    # "Bias" here is the magnitude of the mean gap between the two groups.
    if abs(balanced_mean) < abs(real_mean):
        print("Balanced method reduced bias.")
    else:
        print("Balanced method did not reduce bias.")

    # Significance of the change in gap. Checking whether the balanced mean
    # falls outside the unbalanced CI treats the balanced mean as an exact
    # constant and ignores its own sampling variability, so it overstates
    # significance. Welch's two-sample t-test accounts for the spread of both
    # sets of differences and does not assume equal variances.
    # NOTE(review): if both runs were evaluated on identical splits, a paired
    # test (stats.ttest_rel) would be the more powerful choice - confirm.
    p_value = stats.ttest_ind(real_diff, balanced_diff, equal_var=False).pvalue
    print(f"Welch's t-test p-value: {p_value:.4f}")

    if p_value < 0.05:
        print("The difference is statistically significant at the 95% confidence level.")
    else:
        print("The difference is not statistically significant at the 95% confidence level.")

In [None]:
metrics_to_compare = ['Overall_ACC', 'Overall_AUC', 'Overall_MCC']

# Print the mean eth0 - eth1 gap of the unbalanced and balanced runs, using
# whichever calculate_difference() definition is currently active in the
# notebook session.
for metric in metrics_to_compare:
    real_diff = calculate_difference(results_real, metric)
    balanced_diff = calculate_difference(results_balanced, metric)

    unbalanced_gap = np.mean(real_diff)
    balanced_gap = np.mean(balanced_diff)

    print(
        f"\nComparing {metric}:",
        f"Unbalanced mean difference: {unbalanced_gap:.4f}",
        f"Balanced mean difference: {balanced_gap:.4f}",
        sep="\n",
    )
    

In [None]:
def calculate_difference_full(results, metric, group_size=3):
    """Return the ``full - eth1`` gap after averaging consecutive result groups.

    The full-test-set list and the ethnicity-1 list are first collapsed by
    averaging every ``group_size`` consecutive entries; the gap is then
    computed between the collapsed lists, position by position.

    results: dict holding the lists ``f'{metric}_ls'`` and
        ``f'{metric}_ls_eth1'``.
    metric: key prefix, e.g. ``'Overall_ACC'``.
    group_size: number of consecutive entries averaged together (default 3,
        the value this notebook has always used).

    Returns a list with one gap per group. A trailing group with fewer than
    ``group_size`` entries is averaged over the entries it actually has.

    Raises ValueError if ``group_size`` is smaller than 1.
    """
    if group_size < 1:
        raise ValueError("group_size must be a positive integer")

    def average_groups(lst):
        # Divide by the real chunk length: dividing by group_size would
        # under-estimate the mean of a short final chunk.
        chunks = (lst[i:i + group_size] for i in range(0, len(lst), group_size))
        return [sum(chunk) / len(chunk) for chunk in chunks]

    # Average the results
    avg_full = average_groups(results[f'{metric}_ls'])
    avg_eth1 = average_groups(results[f'{metric}_ls_eth1'])

    # Calculate the difference between averaged results
    return [whole - minority for whole, minority in zip(avg_full, avg_eth1)]

def mean_confidence_interval(data, confidence=0.95):
    """Return the mean of ``data`` with a Student-t confidence interval.

    data: sequence of numbers.
    confidence: two-sided confidence level (default 0.95).

    Returns ``(mean, lower, upper)`` where the bounds lie one half-width
    (standard error times the t quantile, ``n - 1`` degrees of freedom)
    either side of the mean.
    """
    observations = 1.0 * np.array(data)
    count = len(observations)
    avg = np.mean(observations)
    std_err = stats.sem(observations)
    reach = std_err * stats.t.ppf((1 + confidence) / 2., count - 1)
    return avg, avg - reach, avg + reach

metrics_to_compare = ['Overall_ACC', 'Overall_AUC', 'Overall_MCC']

print("Full dataset vs minority class \n")

# For each metric, compare the group-averaged (full test set - ethnicity 1)
# performance gap of the unbalanced run (results_real) with that of the
# balanced run (results_balanced).
for metric in metrics_to_compare:
    real_diff = calculate_difference_full(results_real, metric)
    balanced_diff = calculate_difference_full(results_balanced, metric)

    real_mean, real_ci_low, real_ci_high = mean_confidence_interval(real_diff)
    balanced_mean, balanced_ci_low, balanced_ci_high = mean_confidence_interval(balanced_diff)

    print(f"\nComparing {metric}:")
    print(f"Unbalanced mean difference: {real_mean:.4f} (95% CI: {real_ci_low:.4f} to {real_ci_high:.4f})")
    print(f"Balanced mean difference: {balanced_mean:.4f} (95% CI: {balanced_ci_low:.4f} to {balanced_ci_high:.4f})")

    # "Bias" here is the magnitude of the mean gap between the two series.
    if abs(balanced_mean) < abs(real_mean):
        print("Balanced method reduced bias.")
    else:
        print("Balanced method did not reduce bias.")

    # Significance of the change in gap. Checking whether the balanced mean
    # falls outside the unbalanced CI treats the balanced mean as an exact
    # constant and ignores its own sampling variability, so it overstates
    # significance. Welch's two-sample t-test accounts for the spread of both
    # sets of differences and does not assume equal variances.
    # NOTE(review): if both runs were evaluated on identical splits, a paired
    # test (stats.ttest_rel) would be the more powerful choice - confirm.
    p_value = stats.ttest_ind(real_diff, balanced_diff, equal_var=False).pvalue
    print(f"Welch's t-test p-value: {p_value:.4f}")

    if p_value < 0.05:
        print("The difference is statistically significant at the 95% confidence level.")
    else:
        print("The difference is not statistically significant at the 95% confidence level.")