# Miscellaneous Analyses for Prediction with RNNModel by Darts

## Which chunk IDs were predicted?

This extraction is performed to check whether the same chunks are considered in the ARIMA(X) and the RNNModel approaches. It assumes that the prediction series follow the naming convention `pred_series_{parameter}_{n_chunks}_window{window_nr}.pickle`.

In [None]:
from collections import defaultdict
import os
import pickle5 as pickle

# Define variables to adjust
n_chunks = 2000
style = 'all'

path = f'../../data/chunk_ids/{n_chunks}_chunks/{style}'

# Maps '{parameter}_{n_chunks}' (both parsed from the file name) to the chunk IDs
# collected over all window files of that combination
chunk_ids = defaultdict(list)

for file in os.listdir(path):
    file_path = os.path.join(path, file)
    if not (os.path.isfile(file_path) and file.startswith('pred_series')):
        continue

    # Load current prediction series (context manager closes the file even if loading fails)
    # NOTE: pickle executes arbitrary code on load - only use it on files created by this project
    with open(file_path, 'rb') as current_pred_series_f:
        current_pred_series = pickle.load(current_pred_series_f)

    # Extract substrings; the chunk count from the file name gets its own name so that
    # the configuration variable n_chunks defined above is not overwritten
    name_parts = file.split('_')
    parameter = name_parts[2]
    file_n_chunks = name_parts[3]

    # Add partial list of chunk IDs to dict (defaultdict creates the list on first access)
    chunk_ids[f'{parameter}_{file_n_chunks}'].extend(current_pred_series.keys())

# Save the combined list of every parameter/chunk-count combination
for key, current_chunk_ids in chunk_ids.items():
    key_parts = key.split('_')
    print(f'{key_parts[0]} with {key_parts[1]} chunks: {len(current_chunk_ids)} chunks for prediction')

    with open(f'{path}/chunk_ids_{key}.pickle', 'wb') as current_chunk_ids_f:
        pickle.dump(current_chunk_ids, current_chunk_ids_f, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
import pandas as pd

# Check if combined chunk IDs match expected ones
for parameter in ['HR', 'BP', 'O2']:
    # Extract list with chunk IDs from prediction
    with open(f'{path}/chunk_ids_{parameter}_{n_chunks}.pickle', 'rb') as current_chunk_ids_f:
        current_chunk_ids_pred = pickle.load(current_chunk_ids_f)

    # Extract list with expected chunk IDs.
    # The resampling output is addressed with the lower-case parameter name, as in all
    # other cells of this notebook (upper case fails on case-sensitive file systems).
    resampled_chunks = pd.read_parquet(
        f'../../data/resampling/resample_output_{parameter.lower()}_first{n_chunks}.parquet',
        engine='pyarrow')

    # Expected are all chunks with more than 12 rows; count rows per chunk in one pass
    chunk_lengths = resampled_chunks.groupby('CHUNK_ID_FILLED_TH', sort=False).size()
    current_chunk_ids_original = chunk_lengths[chunk_lengths > 12].index.tolist()

    # Inform if chunk IDs from prediction don't match expected ones
    if set(current_chunk_ids_pred) != set(current_chunk_ids_original):
        print(f'There are different chunk IDs than expected for {parameter} with {n_chunks} chunks')

## Which chunks are affected by the ValueError?

ValueErrors were thrown during the confusion matrix generation for the O2 runs with 1,000 chunks and for all runs with 15,000 chunks, which led to predictions full of NaNs. They originated in our resampling of the chunks, in which a (very small) number of individual data points were missing and were therefore filled with NaN values by Darts by default. The following code cell only includes the final extraction of the chunk IDs where values were missing.

Note: It does not matter which resampling method is investigated because they all deal with the same chunk IDs. We randomly chose the MEDIAN method.

In [None]:
from darts import TimeSeries
import pandas as pd

for n_chunks in [1000, 2000, 15000]:
    for parameter in ['hr', 'bp', 'o2']:
        resampled_chunks = pd.read_parquet(f'../../data/resampling/resample_output_{parameter}_first{n_chunks}.parquet',
                                           engine='pyarrow')

        # Extract relevant (= minimal length 13) chunks.
        # A single groupby pass replaces one full-frame boolean mask per chunk ID;
        # sort=False keeps the chunks in order of first appearance.
        relevant_series = {
            chunk_id: TimeSeries.from_dataframe(
                df=current_series,
                time_col='CHARTTIME',
                value_cols=['VITAL_PARAMTER_VALUE_MEDIAN_RESAMPLING'],
                freq='H')
            for chunk_id, current_series in resampled_chunks.groupby('CHUNK_ID_FILLED_TH', sort=False)
            if len(current_series) > 12
        }

        # Look for chunks with NaN values (missing values are filled by Darts per default).
        # The exported frame only holds the value column (time is the index), so it can
        # be checked directly.
        chunk_ids_with_nan = [
            chunk_id
            for chunk_id, series in relevant_series.items()
            if series.pd_dataframe().isnull().values.any()
        ]

        print(f'Chunk IDs with missing values for {parameter.upper()} with {n_chunks} chunks: \n{chunk_ids_with_nan}\n')

## How long does an ICU stay last?

In [None]:
import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

# Collect one frame of chunk durations per parameter and concatenate them once at the end
duration_frames = []
for parameter in ['hr', 'bp', 'o2']:
    # Read cleaned, resampled and chunked CHARTEVENTS
    resampled_chartsevents = pd.read_parquet(f'../../data/resampling/resample_output_{parameter}_first2000.parquet',
                                             engine='pyarrow')

    # Row count, first and last chart time of every chunk in a single pass
    charttime_stats = (resampled_chartsevents
                       .groupby('CHUNK_ID_FILLED_TH', sort=False)['CHARTTIME']
                       .agg(['size', 'min', 'max']))

    # Only chunks with more than 12 rows are considered
    charttime_stats = charttime_stats[charttime_stats['size'] > 12]

    # Duration in full hours (floor division drops the started hour)
    chunk_durations = ((charttime_stats['max'] - charttime_stats['min']).dt.total_seconds() // 3600).tolist()

    duration_frames.append(pd.DataFrame({'PARAMETER': [parameter.upper()] * len(chunk_durations),
                                         'DURATION_IN_H': chunk_durations}))

plotdata = pd.concat(duration_frames, axis=0, ignore_index=True)

# Visualize durations
sns.set_style('whitegrid')
sns.boxplot(data=plotdata,
            x='DURATION_IN_H',
            y='PARAMETER',
            palette=sns.color_palette('colorblind'))
plt.title('Chunk Durations')
plt.xlabel('Duration (in hours)')
plt.ylabel('Parameter')

## How many ICU stays are there, and how often does each appear in the chunks?

In [None]:
import pandas as pd

for i, parameter in enumerate(['hr', 'bp', 'o2']):
    parquet_file = f'../../data/resampling/resample_output_{parameter}_first2000.parquet'
    chunks = pd.read_parquet(parquet_file, engine='pyarrow')

    # The part of a chunk ID in front of the first underscore is taken as ICU stay ID
    icustay_ids = [chunk_id.split('_')[0] for chunk_id in pd.unique(chunks['CHUNK_ID_FILLED_TH'])]

    # Show how many different ICU stays exist per parameter
    n_distinct_icustays = len(set(icustay_ids))
    print(parameter, n_distinct_icustays)

In [None]:
import matplotlib.pyplot as plt
from  matplotlib.ticker import FuncFormatter
import pandas as pd
import pickle5 as pickle
import seaborn as sns

for i, parameter in enumerate(['hr', 'bp', 'o2']):
    resampled_chartsevents = pd.read_parquet(f'../../data/resampling/resample_output_{parameter}_first2000.parquet',
                                             engine='pyarrow')

    # Collect non-unique ICU stay IDs (one entry per distinct chunk)
    icustay_ids = [chunk_id.split('_')[0]
                   for chunk_id in pd.unique(resampled_chartsevents['CHUNK_ID_FILLED_TH'])]
    unique_icustay_ids = set(icustay_ids)

    # The number printed is the amount of distinct ICU stay IDs
    print(f'{len(unique_icustay_ids)} ICU stay IDs appear at least once in 2,000 {parameter.upper()} chunks')

    # Persist the distinct ICU stay IDs (context manager closes the file even on errors)
    with open(f'../../data/icustay_ids/icustay_ids_{parameter.upper()}.pickle', 'wb') as unique_icustay_ids_f:
        pickle.dump(unique_icustay_ids, unique_icustay_ids_f, protocol=pickle.HIGHEST_PROTOCOL)

    # Count how often each ICU stay ID appears.
    # value_counts replaces DataFrame.append (removed in pandas 2.0) and the quadratic
    # list.count loop; it also yields integer counts.
    icustay_id_counts = (pd.Series(icustay_ids, name='ICUSTAY_ID')
                         .value_counts()
                         .rename_axis('ICUSTAY_ID')
                         .reset_index(name='Count'))

    print(icustay_id_counts.Count.describe())

    # Plot appearance count.
    # No tick formatter is set: the count plot has a categorical axis whose tick locations
    # are positions 0, 1, 2, ..., so formatting them as int(x) would print the positions
    # instead of the counts. With integer counts the default labels already read 1, 2, ...
    sns.set_style('whitegrid')
    plt.figure(i)
    ax = sns.countplot(data=icustay_id_counts,
                       x='Count',
                       color=sns.color_palette('colorblind')[0])
    plt.title(f'Number of Appearances of ICU Stay IDs in 2,000 {parameter.upper()} Chunks')
    plt.xlabel('Number of Appearances')
    plt.ylabel('Count')

## What are the patient characteristics?

In [None]:
import datetime
import pandas as pd
import pickle5 as pickle

# Create plot data with patient information of chunks
plotdata = pd.DataFrame(columns=['ICUSTAY_ID', 'PARAMETER', 'LEAVE_CHARTTIME', 'DATE_OF_LEAVE', 'SUBJECT_ID', 'GENDER',
                                 'EXPIRE_FLAG', 'DOB', 'DOD', 'AGE'])

# Add ICU stay IDs from our chunks
for parameter in ['bp', 'hr', 'o2']:
    # Read the pickled ICU stay IDs of this parameter (context manager closes the file even on errors)
    with open(f'../../data/icustay_ids/icustay_ids_{parameter.upper()}.pickle', 'rb') as icustay_ids_f:
        icustay_ids = pickle.load(icustay_ids_f)

    plotdata = pd.concat([plotdata, pd.DataFrame(
        {'PARAMETER': [parameter.upper()] * len(icustay_ids),
         'ICUSTAY_ID': [int(float(i)) for i in icustay_ids]}
    )], axis=0, ignore_index=True)

# Add related subject IDs from CHARTEVENTS + chart times/ date at ICU leave
chartevents = pd.read_parquet('../../data/chartevents_subset.parquet', engine='pyarrow')
subject_ids, leave_charttimes = list(), list()

for icustay_id in plotdata['ICUSTAY_ID']:
    # Filter the chart events of this ICU stay once and reuse them for both lookups
    stay_events = chartevents[chartevents['ICUSTAY_ID'] == icustay_id]
    subject_ids.append(stay_events['SUBJECT_ID'].iloc[0])  # get single subject ID
    leave_charttimes.append(stay_events.sort_values('CHARTTIME')['CHARTTIME'].iloc[-1])  # get last chart time
plotdata['SUBJECT_ID'] = subject_ids
plotdata['LEAVE_CHARTTIME'] = leave_charttimes
plotdata['DATE_OF_LEAVE'] = pd.to_datetime(plotdata['LEAVE_CHARTTIME']).dt.date

# Read patient data
patients = pd.read_csv('../../data/mimic-iii-clinical-database-1.4/PATIENTS.csv',
                       parse_dates=['DOB', 'DOD', 'DOD_HOSP', 'DOD_SSN'],
                       dtype={
                           'ROW_ID': 'float64',  # int according to specification
                           'SUBJECT_ID': 'float64',  # int according to specification
                           'GENDER': 'object',  # varchar(5) according to specification
                           'EXPIRE_FLAG': 'object'  # varchar(5) according to specification
                       })

# Add patient columns.
# The patient values are collected parameter by parameter (BP, HR, O2) and, within a
# parameter, sorted by SUBJECT_ID. plotdata is sorted by the same keys further down, so
# the collected lists line up with its rows positionally.
# NOTE(review): this relies on SUBJECT_ID being unique in PATIENTS - confirm.
genders, expire_flags, dobs, dods = list(), list(), list(), list()
plotdata_cleaned = pd.DataFrame()

for parameter in ['bp', 'hr', 'o2']:
    plotdata_param = plotdata[plotdata['PARAMETER'] == parameter.upper()]

    # Reduce patient data to plot data and vice versa (one row per subject is kept)
    patients_param = patients[patients['SUBJECT_ID'].isin(plotdata_param['SUBJECT_ID'].tolist())]
    plotdata_param = plotdata_param[plotdata_param['SUBJECT_ID'].isin(patients_param['SUBJECT_ID'].tolist())]\
        .drop_duplicates(subset='SUBJECT_ID', keep='last')
    plotdata_cleaned = pd.concat([plotdata_cleaned, plotdata_param])

    # Collect rows of cols per parameter
    patients_param = patients_param.sort_values('SUBJECT_ID')
    genders.extend(patients_param['GENDER'].tolist())
    expire_flags.extend(patients_param['EXPIRE_FLAG'].tolist())
    dobs.extend(patients_param['DOB'].tolist())
    dods.extend(patients_param['DOD'].tolist())

plotdata = plotdata_cleaned

# Sort plotdata according to list filling above
plotdata = plotdata.sort_values(['PARAMETER', 'SUBJECT_ID'])

plotdata['GENDER'] = genders
plotdata['EXPIRE_FLAG'] = expire_flags
plotdata['DOB'] = dobs
plotdata['DOD'] = dods
# Add age column (either age at death or age at ICU leave)
def calc_age(death_date, birth_date):
    """Return the age in completed years reached on ``death_date``.

    Both arguments are date strings in the format ``%Y-%m-%d``. Despite its name,
    ``death_date`` may be any reference date (e.g. the date of ICU leave).

    The age is derived from the calendar dates: the difference of the years, reduced
    by one if the birthday has not yet been reached in the reference year. Dividing
    the day difference by a fixed year length (previously 364 days) overstates ages.

    Raises:
        ValueError: If one of the strings does not match ``%Y-%m-%d``.
    """
    reference = datetime.datetime.strptime(death_date, '%Y-%m-%d')
    birth = datetime.datetime.strptime(birth_date, '%Y-%m-%d')

    # Tuple comparison checks month first and day second
    birthday_reached = (reference.month, reference.day) >= (birth.month, birth.day)
    return reference.year - birth.year - (0 if birthday_reached else 1)

# Age of dead patients = age at death.
# The filtered frames are copied so that the AGE column is written to an independent
# frame and not to a slice of plotdata.
dead_patients = plotdata[plotdata['EXPIRE_FLAG'] == '1'].copy()
dead_patients['AGE'] = [calc_age(death_date, birth_date) for death_date, birth_date in
                        zip(dead_patients['DOD'].astype(str), dead_patients['DOB'].astype(str))]

# Note: If the patient was alive at least 90 days post hospital discharge, DOD is null
# Source: https://mit-lcp.github.io/mimic-schema-spy/tables/patients.html
# Age of living patients = age 90 days after the date of ICU leave
living_patients = plotdata[plotdata['EXPIRE_FLAG'] == '0'].copy()
living_patients['AGE'] = [calc_age(death_date, birth_date) for death_date, birth_date in
                          zip((living_patients['DATE_OF_LEAVE'] + datetime.timedelta(days=90)).astype(str),
                              living_patients['DOB'].astype(str))]

plotdata = pd.concat([dead_patients, living_patients], axis=0, ignore_index=True)

# Remove absurd ages (e.g. created because of very early date of birth)
plotdata = plotdata[plotdata['AGE'] < 150].copy()

# Plot labels per parameter; exact value replacement, no regular expressions needed
plotdata['PARAMETER_LABEL'] = plotdata['PARAMETER'].replace(
    {'HR': '$HR$', 'BP': '$NBPs$', 'O2': '$S_pO_2$'})

plotdata.to_parquet('../../data/patient_info_of_2000chunks.parquet', engine='pyarrow')
print(plotdata.head())

In [None]:
import matplotlib.pyplot as plt
import seaborn as sns

sns.set_style('whitegrid')

# One count plot per patient attribute (gender, dead/ alive), bars split by parameter
colorblind_palette = sns.color_palette('colorblind')

for col_name in ['Gender', 'Expire_Flag']:
    column = col_name.upper()
    axis_label = col_name.replace('_', ' ')
    target_file = f'../../plots/patient_analysis/chunk_patients_{col_name.lower()}.pdf'

    plt.figure(figsize=(6, 4), dpi=72)
    sns.countplot(data=plotdata, x=column, hue='PARAMETER_LABEL', palette=colorblind_palette)

    plt.xlabel(axis_label)
    plt.ylabel('Count')
    # Place the legend to the right of the axes
    plt.legend(title='Parameter:', bbox_to_anchor=(1.05, 1), loc=2, borderaxespad=0.)

    plt.savefig(target_file, dpi=300, bbox_inches='tight')

In [None]:
# Visualize ages per parameter: all patients (blue) overlaid with the dead patients (black)
fig, ax = plt.subplots(
    nrows=3,
    ncols=1,
    figsize=(6, 14),
    sharex=True,
    dpi=72
    )

for i, parameter in enumerate(['hr', 'bp', 'o2']):
    # Each subplot is titled with one parameter, so it only shows the patients of that parameter
    plotdata_param = plotdata[plotdata['PARAMETER'] == parameter.upper()]

    g1 = sns.histplot(ax=ax[i],
                      data=plotdata_param,  # blue are all patients
                      x='AGE',
                      kde=True,
                      color=sns.color_palette('colorblind')[0],  # palette needs hue, so pick a color
                      bins=50)
    g2 = sns.histplot(ax=ax[i],
                      data=plotdata_param[plotdata_param['EXPIRE_FLAG'] == '1'],  # black are dead patients
                      x='AGE',
                      kde=True,
                      color='black',
                      bins=50)
    ax[i].set_title(plotdata_param.iloc[0]['PARAMETER_LABEL'], fontweight='bold', fontsize=14)

    # Only the lowest subplot gets an x-axis label
    if i != 2:
        g1.set(xlabel=None)
        g2.set(xlabel=None)
    else:
        ax[i].set_xlabel('Age')

fig.tight_layout()
plt.savefig('../../plots/patient_analysis/chunk_patients_ages.pdf', dpi=300)

## Are the chunks non-stationary?

A time series is non-stationary if its properties depend on the time at which the series is observed (i.e., if it has a trend or seasonality). This can be checked by splitting the chunk data into partitions and comparing the means and variances of the partitions. If the differences are statistically significant, the time series is likely non-stationary.

**Source:** Box, G. E. P., Jenkins, G. M., Reinsel, G. C., & Ljung, G. M. (2015). "Time series analysis: Forecasting and control (5th ed)". Hoboken, New Jersey: John Wiley & Sons.

**Result:** As assumed, the chunks are rather stationary.

In [None]:
import pandas as pd

for parameter in ['hr', 'bp', 'o2']:
    print(f'\nPARAMETER: {parameter.upper()}')
    chunks = pd.read_parquet(f'../../data/resampling/resample_output_{parameter}_first2000.parquet', engine='pyarrow')

    value_col = 'VITAL_PARAMTER_VALUE_MEDIAN_RESAMPLING'

    # Split the chunk IDs (in order of first appearance) into two halves.
    # A dedicated name is used for the split position so that the notebook-wide
    # n_chunks setting is not overwritten.
    unique_chunk_ids = pd.unique(chunks['CHUNK_ID_FILLED_TH'])
    split_position = len(unique_chunk_ids) // 2

    first_partition = chunks[chunks['CHUNK_ID_FILLED_TH'].isin(unique_chunk_ids[:split_position])]
    second_partition = chunks[chunks['CHUNK_ID_FILLED_TH'].isin(unique_chunk_ids[split_position:])]

    print(f'Mean of First Partition: {first_partition[value_col].mean()}')
    print(f'Mean of Second Partition: {second_partition[value_col].mean()}')

    # Variance of the value column only: DataFrame.var() over all columns fails on the
    # non-numeric columns in recent pandas versions
    print(f'Variance of First Partition: {first_partition[value_col].var()}')
    print(f'Variance of Second Partition: {second_partition[value_col].var()}')