In [None]:
import pandas as pd
import numpy as np
from fastparquet import ParquetFile
import pickle

import torch
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, normalize

VITAL_DATA_PATH = "./IMPALA_Clinical_Data/Clean Data/Vital Signs Data/df_merged.parquet"

#### Load and preprocess the data

In [None]:

def read_vital_data(path):
    """
    Load the vital sign table stored in the Parquet file at `path` and return
    it as a Pandas DataFrame.
    """
    return ParquetFile(path).to_pandas()


def extract_vital_sign(df):
    """
    Keep only the columns whose name starts with 'ECGHR', 'ECGRR', 'SPO2' or
    'record_id'. All other columns are left out of the returned DataFrame;
    the original column order is kept.
    """

    wanted_prefixes = ('ECGHR', 'ECGRR', 'SPO2', 'record_id')
    is_wanted = df.columns.str.startswith(wanted_prefixes)
    selected_columns = df.columns[is_wanted]
    return df[selected_columns]


def convert_to_date(df):
    """
    Convert the date indices of `df` to Pandas Timestamps.

    The index is replaced on the DataFrame that was passed in (in place) and
    that same DataFrame is returned.
    """

    timestamps = pd.to_datetime(df.index)
    df.index = timestamps
    return df


def assign_patient_to_nan(df, time_interval=48):
    """
    Fill in 'record_id' for the rows where it is NaN.

    The rows are walked in their current order and split into consecutive
    segments, each of which is assigned a single record_id:

    - A NaN row that follows a known record_id stays in that patient's
      segment as long as the time between it and the previous row is at most
      `time_interval` hours.
    - If the gap is larger, the current segment is closed and a new one
      starts; its NaN rows are assigned to the next record_id that is
      encountered.
    - A row carrying a different (non-NaN) record_id closes the current
      segment and starts a new one for that record_id.
    - NaN rows at the end that are never followed by a record_id stay NaN.

    Parameters
    ----------
    df : DataFrame with a 'record_id' column and an index of Timestamps
        (the index values are subtracted from each other).
    time_interval : maximum gap in hours for a NaN row to be attached to the
        previous patient.

    Returns
    -------
    The same DataFrame, with 'record_id' filled in place.
    """

    # Positional writes (`iloc`) are used instead of chained assignment
    # (`df['record_id'][a:b] = ...`), which does not reliably write back to
    # `df` and is a no-op under pandas Copy-on-Write.
    id_col = df.columns.get_loc('record_id')
    seen_ids = df['record_id'].tolist()
    times = df.index
    max_gap = pd.Timedelta(hours=time_interval)

    record_id = None  # record_id of the segment currently being built
    lower_i = 0       # first row position of that segment

    for i, curr_record_id in enumerate(seen_ids):
        curr_missing = pd.isna(curr_record_id)

        if pd.isna(record_id):
            # No record_id known yet for this segment: wait for the first one.
            if not curr_missing:
                record_id = curr_record_id
            continue

        if curr_missing:
            # A gap longer than the time interval means a new patient starts here.
            if abs(times[i - 1] - times[i]) > max_gap:
                df.iloc[lower_i:i, id_col] = record_id
                lower_i = i
                record_id = None

        elif curr_record_id != record_id:
            # A new record_id is encountered: close the previous patient.
            df.iloc[lower_i:i, id_col] = record_id
            lower_i = i
            record_id = curr_record_id

    # Close the final segment only after the last row has been examined, so a
    # last row that starts a new patient is not overwritten.
    if not pd.isna(record_id):
        df.iloc[lower_i:, id_col] = record_id

    return df


def sort_dates_per_patient(df):
    """
    Sort the rows of every patient chronologically (by index).

    Rows are grouped on 'record_id'; the groups keep the order in which the
    patients first appear, and the sorted groups are concatenated.
    """

    patient_groups = df.groupby(df['record_id'], observed=True, sort=False)

    sorted_groups = []
    for _, patient_df in patient_groups:
        sorted_groups.append(patient_df.sort_index())

    return pd.concat(sorted_groups)


def remove_nan_rows(df):
    """
    Remove rows that contain a NaN value in any column except the last one.

    The last column is excluded from the check (presumably it is 'record_id';
    verify against the column order of the input).

    Rows are selected with a positional boolean mask, so the index, including
    its name and any duplicate labels, is carried over unchanged. This also
    works when the index has a name.

    Returns a new DataFrame; the input is not modified.
    """

    has_nan = df.iloc[:, :-1].isna().any(axis=1).to_numpy()
    return df.loc[~has_nan]


In [None]:
# Preprocessing pipeline, one step per line:
#   1. read the data
#   2. remove recruitment, daily and discharge information
#   3. convert the date indices to Pandas Timestamps
#   4. fill in record_id where it is missing
#   5. per patient, sort the data chronologically
#   6. remove rows that contain any NaN values
vital_df = (
    read_vital_data(VITAL_DATA_PATH)
    .pipe(extract_vital_sign)
    .pipe(convert_to_date)
    .pipe(assign_patient_to_nan)
    .pipe(sort_dates_per_patient)
    .pipe(remove_nan_rows)
)
print(vital_df.shape)


#### Normalize vital signs

In [None]:

def normalize_vital_signs(df, method='standardize'):
    """
    Normalize or standardize the vital signs data.

    The last column of `df` is dropped before scaling (presumably it is
    'record_id'; verify against the column order of the input).

    Parameters
    ----------
    df : DataFrame whose columns, except the last, hold the values to scale.
    method : 'standardize' fits a StandardScaler on the data and transforms
        it; 'normalize' passes the data to sklearn's `normalize` with its
        default arguments.

    Returns
    -------
    Array with the scaled values (one column fewer than `df`).

    Raises
    ------
    ValueError if `method` is not one of the two supported names.
    """

    # Validate first: an unknown method used to fall through to the return
    # statement with no result bound, failing with an UnboundLocalError.
    if method not in ('standardize', 'normalize'):
        raise ValueError(
            f"Method unknown: {method!r} (expected 'standardize' or 'normalize')"
        )

    data = df.values[:, :-1]

    if method == 'standardize':
        return StandardScaler().fit_transform(data)

    return normalize(data)


In [None]:

# NOTE(review): in the cells visible here `vital_data` is not used again; the
# sliding window further down is built from the un-normalized `vital_df`.
# Confirm whether the standardized values were meant to feed the model.
vital_data = normalize_vital_signs(vital_df, method='standardize')


#### Apply sliding window

In [None]:

def load_cie_data(filename):
    """
    Read the pickle file `filename` and return the object stored in it
    (the CIE information).
    """

    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(filename, 'rb') as handle:
        return pickle.load(handle)


In [None]:

cie_dict = load_cie_data('saved_CIE')

# Number of recorded events per type, summed over all entries of the dict
cpr = sum(len(events['cpr']) for events in cie_dict.values())
picu = sum(len(events['picu']) for events in cie_dict.values())
death = sum(len(events['death']) for events in cie_dict.values())

print("=== Critical Illness Events ===")
print(f"- CPR  : {cpr}")
print(f"- PICU : {picu}")
print(f"- Death: {death}")
print(f"- Total: {cpr + picu + death}")


In [None]:

def is_cie(cie_dict, record_id, predictive_window):
    """
    Count, per event type, how many timepoints of the predictive window
    coincide with a CIE of the given patient.

    `cie_dict[record_id]` is a dict of event-time lists; the returned list has
    one counter per entry, in the dict's own order (three entries expected).
    A timepoint only counts when it is equal to a recorded event time.
    """

    patient_events = cie_dict[record_id]
    counts = [0, 0, 0]

    # Shortcut: nothing to count when all three event lists are empty
    n_empty = sum(1 for event_times in patient_events.values() if event_times == [])
    if n_empty == 3:
        return counts

    for slot, event_times in enumerate(patient_events.values()):
        hits = 0
        for t in predictive_window:
            if t in event_times:
                hits += 1
        counts[slot] += hits

    return counts



def sliding_window(vital_df, cie_dict, sample_window_hours=4, predictive_window_hours=1):
    """
    Apply a sliding window over the data of every patient and label each
    window with whether a CIE occurred (and what type).

    Per patient an hourly time grid is laid out from the first to the last
    timestamp. A window covers `sample_window_hours` consecutive grid points
    (the sample) followed by `predictive_window_hours` grid points (the
    future that is checked for CIEs with `is_cie`). The window moves one hour
    at a time.

    For every grid point of the sample, the data row with exactly that
    timestamp is used; grid points without a matching row are padded with -1.

    Parameters
    ----------
    vital_df : DataFrame with a 'record_id' column, a datetime index and
        numeric vital sign columns.
    cie_dict : mapping record_id -> dict of event-time lists (see `is_cie`).
    sample_window_hours : number of hourly grid points in each sample.
    predictive_window_hours : number of hourly grid points looked ahead.

    Returns
    -------
    X : array, shape (data samples, sample window size, dimensions)
    y : array, shape (data samples, number of CIE types)
    """

    assert sample_window_hours > 0
    assert predictive_window_hours > 0

    window_hours = sample_window_hours + predictive_window_hours

    X = []
    y = []
    removed_dfs = 0

    # Group on the column name (not a one-element list) so that the group key
    # is the plain record_id on every pandas version.
    for record_id, patient_df in vital_df.groupby('record_id', observed=False):

        if patient_df.shape[0] < window_hours:
            # Skip dataframe if it is too small
            removed_dfs += 1
            continue

        # Turn the patient's data into numpy arrays
        features = patient_df.drop(['record_id'], axis=1)
        timepoints = features.index.values
        data = features.to_numpy(dtype=float)

        # The timestamp lookup below needs chronological order
        order = np.argsort(timepoints, kind='stable')
        timepoints, data = timepoints[order], data[order]

        # Hourly grid; the "+ 1" (one tick) keeps the last timestamp inside
        # the half-open range when it falls exactly on the grid
        begin_time, end_time = timepoints[0], timepoints[-1]
        time_range = np.arange(begin_time, end_time + 1, np.timedelta64(1, "h"))

        if len(time_range) < window_hours:
            removed_dfs += 1
            continue

        # For each grid point, find the data row with that exact timestamp.
        # Rows are looked up by time, not by position inside the window.
        row_pos = np.searchsorted(timepoints, time_range)
        row_pos = np.minimum(row_pos, len(timepoints) - 1)
        has_row = timepoints[row_pos] == time_range

        # Data on the hourly grid, padded with -1 where no row exists
        hourly = -np.ones((len(time_range), data.shape[1]))
        hourly[has_row] = data[row_pos[has_row]]

        for start in range(len(time_range) - window_hours + 1):
            sample_end = start + sample_window_hours
            predictive_window = time_range[sample_end:start + window_hours]

            # Add data sample and its CIE labels to the dataset
            X.append(hourly[start:sample_end])
            y.append(is_cie(cie_dict, record_id, predictive_window))

    print(f'{removed_dfs} DataFrames were too small')
    X = np.array(X) # shape: data samples, sample window size, dimensions
    y = np.array(y) # shape: data samples, number of CIE (outputs)

    return X, y


In [None]:

# Show the input size, build the windowed dataset, then show the output shapes
print(vital_df.shape)
X, y = sliding_window(
    vital_df,
    cie_dict,
    sample_window_hours=8,
    predictive_window_hours=4,
)
print(X.shape, y.shape)


#### Create PyTorch dataloader

In [None]:

class VitalSignDataset(Dataset):
    """
    Dataset pairing each windowed sample in `X` with its label vector in `y`.
    Both inputs are converted with `torch.Tensor` on construction.
    """

    def __init__(self, X, y):
        features = torch.Tensor(X)
        labels = torch.Tensor(y)
        self.X = features
        self.y = labels

    def __len__(self):
        # Number of samples = size of the first dimension of X
        n_samples = self.X.shape[0]
        return n_samples

    def __getitem__(self, idx):
        # X is indexed as a 3-dimensional tensor, y as a 2-dimensional one
        window = self.X[idx, :, :]
        target = self.y[idx, :]
        return window, target


def create_dataloaders(X, y, batch_size=32, seed=42):
    """
    Shuffle the data, split it into a train, val and test set and wrap each
    set in a dataloader.

    NOTE: 30% of the data is held out first, and 66% of that hold-out becomes
    the test set, so train, val and test are split approximately
    0.7, 0.1, 0.2. All three dataloaders shuffle their data.

    Returns the tuple (train_dataloader, val_dataloader, test_dataloader).
    """

    # First split: train versus the hold-out part
    X_train, X_holdout, y_train, y_holdout = train_test_split(
        X, y, test_size=0.3, random_state=seed
    )

    # Second split: divide the hold-out part into val and test
    X_val, X_test, y_val, y_test = train_test_split(
        X_holdout, y_holdout, test_size=0.66, random_state=seed
    )

    splits = ((X_train, y_train), (X_val, y_val), (X_test, y_test))
    datasets = [VitalSignDataset(features, labels) for features, labels in splits]

    train_dataloader, val_dataloader, test_dataloader = (
        DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=False)
        for dataset in datasets
    )

    return train_dataloader, val_dataloader, test_dataloader


In [None]:

train_dataloader, val_dataloader, test_dataloader = create_dataloaders(X, y)

# Take one batch from the test dataloader and show its first sample and label
sample = next(iter(test_dataloader))
batch_features, batch_labels = sample[0], sample[1]

print(batch_features[0])
print(batch_labels[0])
