In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import random
import seaborn as sns
from dateutil.rrule import rrule, SECONDLY, MINUTELY, HOURLY
from collections import defaultdict
from sklearn.preprocessing import normalize
from sklearn.feature_selection import SelectFromModel
from sklearn.svm import LinearSVC
from sklearn.ensemble import ExtraTreesClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.feature_selection import RFE

from preprocessing.load_raw_vital_signs import *
from preprocessing.heart_rate_variability import calculate_hrv

# Input data locations, relative to the notebook's working directory.
RAW_VITAL_DATA_PATH = "./DATA/Raw Data/filtered_df_removed_nan_files.parquet"  # raw vital-sign recordings (parquet)
CLINICAL_DATA = './DATA/Clean Data/IMPALA_Clinical_Data_202308211019_Raw.csv'  # clinical data export (CSV)

### Select patients based on criteria

In [None]:

def read_clinical_df(path):
    """
    Read the clinical CSV export and keep only the columns this notebook needs.

    :param path: path (or file-like object) of the clinical CSV.
    :return: DataFrame with exactly the columns 'record_id' and 'dis_outcome'.
    """
    wanted_columns = ['record_id', 'dis_outcome']
    clinical = pd.read_csv(path, low_memory=False)
    return clinical[wanted_columns]

def select_patients_by_los(df, lower_bound_hours, upper_bound_hours, n_selected_patients=None, seed=42):
    """
    Select a balanced, random set of alive and deceased patients by length of stay.

    The length of stay is measured as the number of clinical rows per patient: a patient
    qualifies when its row count lies in
    [ceil(lower_bound_hours / 4), ceil(upper_bound_hours / 4)].
    NOTE(review): this assumes one clinical row per ~4 hours of stay — confirm.

    :param df: DataFrame with at least the columns 'record_id' and 'dis_outcome'.
               Outcome 1 is counted as alive, 2 as deceased; other values are ignored.
    :param lower_bound_hours: minimum length of stay in hours.
    :param upper_bound_hours: maximum length of stay in hours.
    :param n_selected_patients: number of patients drawn from EACH group. None, or a value
               larger than the smaller group, draws as many as the smaller group allows.
    :param seed: seed of the local random generator, so the draw is reproducible without
               modifying NumPy's or Python's global random state.
    :return: (alive_ids, died_ids) as NumPy arrays of equal length.
    """
    # A local legacy RandomState draws exactly the same ids as the former
    # np.random.seed(seed) + np.random.choice(...) combination, but leaves the
    # global generators untouched for the rest of the notebook.
    rng = np.random.RandomState(seed)

    min_rows = np.ceil(lower_bound_hours / 4)
    max_rows = np.ceil(upper_bound_hours / 4)

    alive_ids, died_ids = [], []

    for patient_id, series in df.groupby('record_id'):

        if not min_rows <= series.shape[0] <= max_rows:
            continue

        outcome = series['dis_outcome'].iloc[0]
        if outcome == 1:
            alive_ids.append(patient_id)
        elif outcome == 2:
            died_ids.append(patient_id)

    print(f'====== Select patients that fit criteria (min {lower_bound_hours}h, max {upper_bound_hours}h) ========')
    print(f'- Number of alive patients: {len(alive_ids)}')
    print(f'- Number of deceased patients: {len(died_ids)}\n')

    # Both groups get the same size so the selection is balanced
    max_n = min(len(alive_ids), len(died_ids))
    if n_selected_patients is None or n_selected_patients > max_n:
        n_selected_patients = max_n

    alive_ids = rng.choice(alive_ids, size=n_selected_patients, replace=False)
    died_ids = rng.choice(died_ids, size=n_selected_patients, replace=False)

    print(f'====== Randomly select {n_selected_patients} patients from both selected groups =======')
    print(f'Alive patients\n{alive_ids}\n')
    print(f'Deceased patients\n{died_ids}')

    return alive_ids, died_ids


In [None]:

# Clinical data reduced to the patient id and outcome columns
df = read_clinical_df(CLINICAL_DATA)

# Equal-sized random samples of alive and deceased patients with a 1-13 hour stay
alive_ids_short, died_ids_short = select_patients_by_los(df, 1, 13)


### Load data

In [None]:

### Read data
# The short-stay dict is extracted from the raw parquet once and cached on disk.
# Set REBUILD_SHORT_DICT = True to re-extract it (this re-reads the raw parquet).
# The long-stay dict is only loaded from its cache.
# NOTE(review): assumes load_patient_dict reads back what save_patient_dict writes
# at the same path, and that the cached short dict was built from the same
# (seeded) patient selection as above — confirm.
SHORT_DICT_PATH = './DATA/Raw Data/raw_patient_dict_short'
LONG_DICT_PATH = './DATA/Raw Data/raw_patient_dict_p30'
REBUILD_SHORT_DICT = False

if REBUILD_SHORT_DICT:
    data_short = read_raw_vital_signs(RAW_VITAL_DATA_PATH, batch_size=10000,
                                      patient_id=list(np.concatenate((alive_ids_short, died_ids_short))))
    save_patient_dict(data_short, SHORT_DICT_PATH)
else:
    data_short = load_patient_dict(SHORT_DICT_PATH)

data_long = load_patient_dict(LONG_DICT_PATH)

print(data_short.keys())
print(data_long.keys())


### Heart rate variability

| Frequency domain | Time domain | Non-linear domain |
|:---:|:---:|:---:|
| lf, hf, lf_hf_ratio, lfnu, hfnu, total_power, vlf | mean_nni, sdnn, sdsd, nni_50, pnni_50, nni_20, pnni_20, rmssd, median_nni, range_nni, cvsd, cvnni, mean_hr, max_hr, min_hr, std_hr | csi, cvi, Modified_csi |

In [None]:

def split_data_into_windows(df, time_unit='m', time_freq=15):
    """
    Clean and split the heart-rate signal into consecutive, non-overlapping time windows.

    Window starts are generated with dateutil's rrule from the first timestamp up to the
    last one, every `time_freq` `time_unit`s. Each window is half-open,
    [start, start + time_freq), so a sample lying exactly on a boundary belongs to one
    window only. Windows that contain no samples are skipped.

    :param df: Pandas DataFrame with at least the columns 'ECGHR' and 'datetime'.
    :param time_unit: time unit of the data window: 'h' (hours), 'm' (minutes) or 's' (seconds).
    :param time_freq: number of time units in the data window.
    :return: (windows, timestamps): a list of 1-D arrays holding the 'ECGHR' values of each
             non-empty window, and a list with the end time of each of those windows.
    """

    # Clean data: keep only the heart-rate signal, indexed and sorted by time
    signal = df[['ECGHR', 'datetime']].set_index('datetime').sort_index()
    if signal.empty:
        return [], []

    times = signal.index
    values = signal['ECGHR'].to_numpy()

    # Split data
    rrule_time = {'h' : HOURLY, 'm' : MINUTELY, 's' : SECONDLY}
    freq = rrule_time[time_unit]
    window_length = pd.Timedelta(time_freq, unit=time_unit)

    windows = []
    timestamps = []

    for start in rrule(freq=freq, interval=time_freq, dtstart=times[0], until=times[-1]):

        end = start + window_length
        # The index is sorted, so each window is a contiguous slice located by binary
        # search; side='left' on both ends gives the half-open interval [start, end).
        first = times.searchsorted(start, side='left')
        last = times.searchsorted(end, side='left')
        if last > first:
            windows.append(values[first:last])
            timestamps.append(end)

    return windows, timestamps


def get_hrv_scores(data, normalize_data=False, threshold=280):
    """
    Calculate heart-rate-variability (HRV) features for every patient.

    Each patient's recording is split with split_data_into_windows (default window
    settings). A window is scored with calculate_hrv only when it holds at least
    `threshold` non-NaN samples. Its scores are kept only when every returned feature
    is a real (non-NaN) number. All other windows count as invalid.

    :param data: mapping patient_id -> DataFrame with at least 'ECGHR' and 'datetime' columns.
    :param normalize_data: if True, scale each feature column of each patient's matrix by
        that patient's maximum value (sklearn normalize, norm='max', axis=0). Scaling is
        per patient, not across the cohort.
    :param threshold: minimum number of non-NaN samples in a window. The default 280 was
        chosen as "5 minutes of data with an error margin of 20 seconds", which presumably
        assumes ~1 sample per second — TODO confirm the sampling rate.
    :return: (hrv_data, time_data): dicts patient_id -> 2-D array (valid windows x features)
        and patient_id -> 1-D array of window end times. Patients without any valid window
        are absent from both dicts.
    """

    hrv_data = defaultdict(list)
    time_data = defaultdict(list)
    n_invalid = 0
    n_total = 0

    for patient_id, df in data.items():

        if df.shape[0] == 0:
            print(f'Patient {patient_id} has not data entries')
            continue

        windows, timestamps = split_data_into_windows(df)

        for window, time in zip(windows, timestamps):

            n_total += 1

            if np.count_nonzero(~np.isnan(window)) >= threshold:

                hrv_features = calculate_hrv(window)

                # Keep the window only if every HRV score is a real number.
                # NOTE(review): the column order follows the dict returned by
                # calculate_hrv and is assumed to match NAMES — confirm.
                if hrv_features and not np.isnan(np.array(list(hrv_features.values()))).any():
                    hrv_data[patient_id].append(list(hrv_features.values()))
                    time_data[patient_id].append(time)
                    continue

            n_invalid += 1

    # Per-patient, per-feature scaling by the maximum value
    if normalize_data:
        hrv_data = {k : normalize(np.array(v), axis=0, norm='max') for k, v in hrv_data.items()}
    else:
        hrv_data = {k : np.array(v) for k, v in hrv_data.items()}

    time_data = {k : np.array(v) for k, v in time_data.items()}

    # Guard against division by zero when no patient produced any window
    if n_total:
        print(f'{round(n_invalid / n_total * 100, 2)}% of windows were invalid')
    else:
        print('No windows were produced; nothing to score')

    return hrv_data, time_data


In [None]:

# HRV feature names in the column order used by the score matrices: indices 0-6 are
# frequency-domain, 7-22 time-domain and 23-25 non-linear-domain features (see table above).
# NOTE(review): this order must match the dict returned by calculate_hrv — confirm.
NAMES = np.array(['lf', 'hf', 'lf_hf_ratio', 'lfnu', 'hfnu', 'total_power', 'vlf',
                  'mean_nni', 'sdnn', 'sdsd', 'nni_50', 'pnni_50', 'nni_20', 'pnni_20',
                  'rmssd', 'median_nni', 'range_nni', 'cvsd', 'cvnni', 'mean_hr',
                  'max_hr', 'min_hr', 'std_hr', 'csi', 'cvi', 'Modified_csi'])

# Per-patient HRV scores (each patient scaled by its own per-feature maximum) and window times
hrv_data_long, time_data_long = get_hrv_scores(data_long, normalize_data=True)
hrv_data_short, time_data_short = get_hrv_scores(data_short, normalize_data=True)


#### Data selection

In [None]:

def split_data_per_category(hrv_data, alive_ids, died_ids):
    """
    Stack the per-patient HRV matrices into one matrix per outcome group.

    Patients are taken in the iteration order of `hrv_data`. Ids that are not keys of
    `hrv_data` are ignored. A group without any matching patient yields an empty
    (0, n_features) array instead of raising.

    :param hrv_data: dict patient_id -> 2-D array (windows x features).
    :param alive_ids: iterable of ids of alive patients.
    :param died_ids: iterable of ids of deceased patients.
    :return: (data_alive, data_died, data_all), where data_all is alive rows followed by died rows.
    """

    # Sets give O(1) membership tests (ids may arrive as lists or NumPy arrays)
    alive_set = set(alive_ids)
    died_set = set(died_ids)

    # Column count used for empty groups, taken from any 2-D score matrix
    n_features = next((v.shape[1] for v in hrv_data.values() if np.ndim(v) == 2), 0)

    def _stack(ids):
        """Concatenate the matrices of the patients in `ids`, preserving dict order."""
        parts = [v for k, v in hrv_data.items() if k in ids]
        return np.concatenate(parts) if parts else np.empty((0, n_features))

    data_alive = _stack(alive_set)
    data_died = _stack(died_set)
    data_all = np.concatenate((data_alive, data_died))

    return data_alive, data_died, data_all


def choose_last_hours(data, time_data, n_hours=12):
    """
    Keep only the scores from the final `n_hours` hours of each patient's recording.

    The cut-off is measured back from each patient's last time entry. Rows whose time
    lies in [last - n_hours, last] are kept.

    :param data: dict patient_id -> 2-D array of scores (one row per window).
    :param time_data: dict patient_id -> 1-D array of window times, row-aligned with `data`.
    :param n_hours: length of the trailing period to keep, in hours.
    :return: (new_data, new_time) dicts with the same keys as `data`.
    """

    new_data = dict()
    new_time = dict()
    span = pd.Timedelta(n_hours, unit='h')

    for patient_id, scores in data.items():

        # Look the times up by key: pairing the two dicts by iteration order silently
        # mismatches patients whenever their insertion orders differ.
        time = np.asarray(time_data[patient_id])

        if len(time) == 0:
            new_data[patient_id] = scores[:0]
            new_time[patient_id] = time
            continue

        end = time[-1]
        start = end - span
        # dtype=bool keeps this a boolean mask even for element types numpy cannot vectorise
        mask = np.array([start <= t <= end for t in time], dtype=bool)

        new_data[patient_id] = scores[mask]
        new_time[patient_id] = time[mask]

    return new_data, new_time


In [None]:

# Hard-coded long-stay cohort (15 alive, 15 deceased). Unlike the short-stay cohort, these
# ids are not produced by select_patients_by_los in this notebook.
# TODO: document how these ids were selected.
alive_ids_long = ['Z-H-0182', 'Z-H-0155', 'Z-H-0336', 'Z-H-0290', 'B-S-0007', 'Z-H-0373',
             'B-N-0063', 'B-S-0159', 'Z-H-0376', 'Z-H-0130', 'Z-H-0144', 'B-S-0166',
             'Z-H-0044', 'Z-H-0173', 'B-S-0151']

died_ids_long = ['Z-H-0114', 'B-N-0080', 'B-N-0084', 'Z-H-0348', 'Z-H-0198', 'Z-H-0032',
            'Z-H-0350', 'Z-H-0185', 'Z-H-0054', 'B-N-0058', 'B-S-0242', 'Z-H-0120',
            'Z-H-0308', 'B-S-0292', 'Z-H-0116']

# Using all windows instead of the last 12 hours would mean calling
# split_data_per_category directly on hrv_data_long / hrv_data_short.

# Keep each patient's last 12 hours of windows, then stack the rows per outcome group
hrv_data_12h_long, time_data_12h_long = choose_last_hours(hrv_data_long, time_data_long, n_hours=12)
data_alive_12h_long, data_died_12h_long, data_all_12h_long = split_data_per_category(hrv_data_12h_long, alive_ids_long, died_ids_long)

hrv_data_12h_short, time_data_12h_short = choose_last_hours(hrv_data_short, time_data_short, n_hours=12)
data_alive_12h_short, data_died_12h_short, data_all_12h_short = split_data_per_category(hrv_data_12h_short, alive_ids_short, died_ids_short)


### Feature selection

#### Sklearn methods

In [None]:

def feature_selection(data_alive, data_died, n_seeds=10, visualize=True):
    """
    Rank features by how often three classifiers and two sklearn selectors pick them.

    Rows of `data_alive` are labelled 1 and rows of `data_died` are labelled 0. For every
    seed in range(n_seeds), each of LogisticRegression, LinearSVC and ExtraTreesClassifier
    is used twice:
    - once through SelectFromModel (importance based, at most 10 features);
    - once through RFE (recursive elimination down to 10 features).
    Each selection of a feature adds one vote.

    :param data_alive: 2-D array (samples x features) for the class labelled 1.
    :param data_died: 2-D array (samples x features) for the class labelled 0.
    :param n_seeds: number of random seeds to repeat the selection with.
    :param visualize: if True, draw a bar chart of the votes with feature columns
                      0-6, 7-22 and 23-25 as separate bar series.
    :return: feature indices sorted by descending vote count.
    """

    labels = np.concatenate((np.ones(data_alive.shape[0]), np.zeros(data_died.shape[0])))
    samples = np.concatenate((data_alive, data_died))

    votes = np.zeros(samples.shape[1])

    for seed in range(n_seeds):

        estimators = (
            LogisticRegression(max_iter=10000, random_state=seed),
            LinearSVC(dual='auto', max_iter=10000, random_state=seed),
            ExtraTreesClassifier(random_state=seed),
        )

        for estimator in estimators:
            # Importance-based selection on the fitted estimator
            fitted = estimator.fit(samples, labels)
            importance_selector = SelectFromModel(fitted, prefit=True, max_features=10)
            votes += importance_selector.get_support().astype(int)

            # Recursive feature elimination with the same estimator
            rfe_selector = RFE(estimator=estimator, n_features_to_select=10).fit(samples, labels)
            votes += rfe_selector.get_support().astype(int)

    if visualize:
        # One bar series per HRV domain so each gets its own colour and legend entry
        for first, last in ((0, 7), (7, 23), (23, 26)):
            plt.bar(range(first, last), votes[first:last])
        plt.xticks(range(len(NAMES)), NAMES, rotation=-90)
        plt.ylabel('N times in the 10')
        plt.title(f'Features selection (3 classifiers, 2 selectors, {n_seeds} seeds)')
        plt.legend(['Frequency-domain', 'Time-domain', 'Non Linear-domain'])
        plt.tight_layout()
        plt.show()

    return np.argsort(-votes)


In [None]:

# Alive (label 1) vs deceased (label 0), long-stay cohort
sorted_features = feature_selection(data_alive_12h_long, data_died_12h_long)
print(sorted_features)

# Alive (label 1) vs deceased (label 0), short-stay cohort
sorted_features = feature_selection(data_alive_12h_short, data_died_12h_short)
print(sorted_features)

# Long (label 1) vs short (label 0) stay among deceased patients; the function's
# alive/died parameter names do not apply to this comparison
sorted_features = feature_selection(data_died_12h_long, data_died_12h_short)
print(sorted_features)

# Long (label 1) vs short (label 0) stay among alive patients
sorted_features = feature_selection(data_alive_12h_long, data_alive_12h_short)
print(sorted_features)


#### Histograms

In [None]:

def plot_histograms(alive, died, feature_idx=(2, 3, 7, 9, 14, 15, 20, 24), bins=50):
    """
    Overlay the distributions of selected HRV features for alive and deceased patients.

    For each feature, both groups are binned with the same edges, computed from the
    pooled values of that feature. This makes the two overlaid histograms in each panel
    directly comparable.

    :param alive: 2-D array (windows x features) for alive patients.
    :param died: 2-D array (windows x features) for deceased patients.
    :param feature_idx: column indices to plot, one panel each; titles come from NAMES.
    :param bins: number of shared histogram bins per feature.
    """

    n_cols = 3
    n_rows = max(1, int(np.ceil(len(feature_idx) / n_cols)))
    fig = plt.figure(figsize=(10, 10))

    for panel, i in enumerate(feature_idx, start=1):

        ax = fig.add_subplot(n_rows, n_cols, panel)

        # Shared bin edges: edges computed separately per group would give the two
        # histograms different bin widths and make their counts incomparable.
        edges = np.histogram_bin_edges(np.concatenate((alive[:, i], died[:, i])), bins=bins)

        sns.histplot(alive[:, i], kde=False, bins=edges, color='tab:blue', ax=ax)
        sns.histplot(died[:, i], kde=False, bins=edges, color='tab:orange', ax=ax)

        ax.set_title(NAMES[i])
        ax.legend(['Alive', 'Deceased'])

    fig.tight_layout()
    plt.show()


In [None]:

# The 12-hour arrays are defined per cohort (_long / _short); plot both.
plot_histograms(data_alive_12h_long, data_died_12h_long)
plot_histograms(data_alive_12h_short, data_died_12h_short)


#### Scores over time

In [None]:
# Compare HRV features of one deceased and one alive patient (long-stay cohort) over their
# final hours. The x-axis is aligned on each patient's LAST window, so both curves end at 0
# and the tick labels read as hours before that last window.
DIED_EXAMPLE_ID = 'Z-H-0114'
ALIVE_EXAMPLE_ID = 'Z-H-0182'
N_HOURS = 12  # same trailing period as passed to choose_last_hours for the 12h data

time_data_1 = time_data_12h_long[DIED_EXAMPLE_ID]
time_data_2 = time_data_12h_long[ALIVE_EXAMPLE_ID]
hrv_data_1 = hrv_data_12h_long[DIED_EXAMPLE_ID]
hrv_data_2 = hrv_data_12h_long[ALIVE_EXAMPLE_ID]

feature_idx = range(5)  # e.g. [2, 3, 7, 9, 14, 15, 20, 24] for the histogram features

# Window times as (negative) hours relative to each patient's last window. A single
# reference per patient replaces the former per-element datetime.now() calls, which
# drifted between samples and aligned the curves on their first window instead of the last.
new_time_died = np.array([(t - time_data_1[-1]) / pd.Timedelta(hours=1) for t in time_data_1])
new_time_alive = np.array([(t - time_data_2[-1]) / pd.Timedelta(hours=1) for t in time_data_2])

# Plot graphs; squeeze=False keeps `ax` indexable even for a single feature
fig, ax = plt.subplots(len(feature_idx), 1, figsize=(15, len(feature_idx)*3), sharex=True, squeeze=False)
ax = ax[:, 0]
fig.suptitle(f"Summary last {N_HOURS} hours", weight='bold', fontsize=18)

for i, id_ in enumerate(feature_idx):
    ax[i].plot(new_time_died, hrv_data_1[:, id_], '-o', label=f'{NAMES[id_]} (deceased)', color='red')
    ax[i].plot(new_time_alive, hrv_data_2[:, id_], '-o', label=f'{NAMES[id_]} (alive)', color='green')
    ax[i].grid(True)
    ax[i].legend(loc='upper left', bbox_to_anchor=(1,1))

# Ticks every 3 hours at real time positions (not sample positions), labelled 12, 9, ..., 0
tick_hours = np.arange(N_HOURS, -1, -3)
ax[-1].set_xticks(-tick_hours)
ax[-1].set_xticklabels(tick_hours, rotation=-45, fontsize=11)
ax[-1].set_xlabel('Hours before last window')
plt.show()
