# Analysis of Results Produced by `mimic_prediction_rnn/prediction_1000_chunks.py`

## Imports

In [None]:
import matplotlib.pyplot as plt
import os
import pandas as pd
import pickle5 as pickle
import seaborn as sns

## Variables to Adjust (Model Type and Vital Parameter)

In [None]:
# Run configuration: the model architecture and the vital-parameter identifier whose
# results are analysed. Both values are interpolated into the data and plot paths.
model_type, parameter = 'LSTM', 'hr'

## Extract Relevant Chunk IDs and Their Predicted Values

In [None]:
# Directory holding the pickles written by the prediction script for this model/parameter
path_to_predicted_pickle_files = f'../../data/darts/1000_chunks/{model_type}/{parameter}'

# Prefix of the per-chunk pickle files that hold the final (non-scaled) prediction
prediction_prefix = '05_non-scaled_prediction_'


def _chunk_id_from_filename(filename):
    """Return the chunk ID encoded in a per-chunk prediction pickle filename.

    The ID is made of the first three '_'-separated fields after ``prediction_prefix``.
    ':' characters are stored as '%3A' in the filename and are decoded back here.
    """
    stem = filename[len(prediction_prefix):]
    return '_'.join(stem.split('_', 3)[:3]).replace('%3A', ':')


# Extract predicted series (keyed by chunk ID). `with` closes the file even if loading fails.
# NOTE: pickle.load can execute arbitrary code; only load pickles produced by this project's scripts.
with open(os.path.join(path_to_predicted_pickle_files, '02_non-scaled_pred_series.pickle'), 'rb') as to_pred_series_f:
    to_pred_series = pickle.load(to_pred_series_f)

print(f'#Chunks to predict: {len(to_pred_series)}')

# Collect pickle file names with final prediction per chunk (sorted for a deterministic order)
prediction_filenames = sorted(
    filename for filename in os.listdir(path_to_predicted_pickle_files)
    if filename.startswith(prediction_prefix) and filename.endswith('final.pickle')
    and os.path.isfile(os.path.join(path_to_predicted_pickle_files, filename)))

# Create dict {chunkID : prediction_timeseries}
predicted_series = dict()

for filename in prediction_filenames:
    with open(os.path.join(path_to_predicted_pickle_files, filename), 'rb') as current_pred_series_f:
        predicted_series[_chunk_id_from_filename(filename)] = pickle.load(current_pred_series_f)

# Quick check that the extracted chunk IDs are correct; report which IDs differ if not
expected_ids = set(to_pred_series.keys())
extracted_ids = set(predicted_series.keys())
if expected_ids != extracted_ids:
    print('There is a mismatch between the expected and extracted chunk IDs !!!')
    print(f'  Expected but not extracted: {sorted(expected_ids - extracted_ids, key=str)}')
    print(f'  Extracted but not expected: {sorted(extracted_ids - expected_ids, key=str)}')

## Add Original Alarm Triggering Booleans

In [None]:
# Extract original series
original_resampled = pd.read_parquet(f'../../data/resampling/resample_output_{parameter}_first1000.parquet', engine='pyarrow')

# Columns carried over from the original series into each per-chunk frame
info_columns = ['CHARTTIME', 'VITAL_PARAMTER_VALUE_MEDIAN_RESAMPLING',
                'THRESHOLD_VALUE_HIGH', 'THRESHOLD_VALUE_LOW']

# Convert original series into dict for confusion matrix comparisons, aka {chunkID : info_columns}
chunks = dict()

for chunk_id in to_pred_series.keys():
    # Filter for current chunk and sort by time
    current_sorted_chunk = original_resampled[original_resampled['CHUNK_ID_FILLED_TH'] == chunk_id].sort_values('CHARTTIME')

    # Keep only the trailing rows that have a prediction; the leading, non-predicted rows are dropped.
    # - tail(n) is used instead of [-n:] because [-0:] would keep every row when n == 0.
    # - .copy() makes the frame independent of original_resampled. This resolves the
    #   SettingWithCopyWarning that the column assignments below used to raise.
    current_chunk = (current_sorted_chunk[info_columns]
                     .tail(len(predicted_series[chunk_id]))
                     .reset_index(drop=True)
                     .copy())

    # Add boolean indicating if high alarms were triggered.
    # Vectorized; like the former row-wise apply, a NaN on either side yields False.
    current_chunk['HIGH_ALARM_TRIGGERED'] = (current_chunk['VITAL_PARAMTER_VALUE_MEDIAN_RESAMPLING']
                                             > current_chunk['THRESHOLD_VALUE_HIGH'])

    # Add boolean indicating if low alarms were triggered
    current_chunk['LOW_ALARM_TRIGGERED'] = (current_chunk['VITAL_PARAMTER_VALUE_MEDIAN_RESAMPLING']
                                            < current_chunk['THRESHOLD_VALUE_LOW'])

    chunks[chunk_id] = current_chunk

## Add Predicted Values and Alarm Triggering Booleans

In [None]:
for chunk_id, chunk in chunks.items():
    # Work on an explicit copy so the column assignments below never write into a view of
    # another DataFrame (avoids SettingWithCopyWarning). The copy is stored back at the end.
    chunk = chunk.copy()

    # Add predicted vital parameter value.
    # The chunk rows were trimmed to exactly len(predicted_series[chunk_id]) trailing rows,
    # so row i corresponds to prediction i. .to_numpy() assigns positionally instead of
    # aligning on the index; index alignment would silently produce NaN if the prediction is
    # indexed by time. A length mismatch now raises instead of NaN-filling.
    # NOTE(review): assumes predicted_series[chunk_id] is a DataFrame with a 'Value' column; confirm.
    chunk['VITAL_PARAMTER_VALUE_PREDICTED'] = predicted_series[chunk_id].Value.to_numpy()

    # Add boolean indicating if predicted value exceeds high alarm threshold
    chunk['HIGH_ALARM_TRIGGERED_PREDICTED'] = (chunk['VITAL_PARAMTER_VALUE_PREDICTED']
                                               > chunk['THRESHOLD_VALUE_HIGH'])

    # Add boolean indicating if predicted value falls below low alarm threshold
    chunk['LOW_ALARM_TRIGGERED_PREDICTED'] = (chunk['VITAL_PARAMTER_VALUE_PREDICTED']
                                              < chunk['THRESHOLD_VALUE_LOW'])

    chunks[chunk_id] = chunk

## Fill Confusion Matrices

In [None]:
# Columns of both confusion-matrix tables (ACCURACY is added once all counts are collected)
confusion_matrix_columns = ['CHUNK_ID', 'TP', 'FN', 'FP', 'TN']


def _confusion_counts(actual, predicted):
    """Return TP/FN/FP/TN counts comparing actual vs. predicted alarm booleans.

    Both arguments are boolean Series taken from the same chunk frame, so they share an
    index. A row is:
    - a true positive when the alarm is triggered in both the original and predicted values;
    - a false negative when it is triggered only in the original;
    - a false positive when it is triggered only in the prediction;
    - a true negative when it is triggered in neither.
    """
    actual = actual.astype(bool)
    predicted = predicted.astype(bool)
    return {
        'TP': int((actual & predicted).sum()),
        'FN': int((actual & ~predicted).sum()),
        'FP': int((~actual & predicted).sum()),
        'TN': int((~actual & ~predicted).sum()),
    }


# Collect one record per chunk and build each table once. DataFrame.append was deprecated
# in pandas 1.4, removed in 2.0, and copied the whole table on every call.
high_records = []
low_records = []

for chunk_id, chunk in chunks.items():
    # Confusion matrix for high threshold analysis
    high_records.append({'CHUNK_ID': chunk_id,
                         **_confusion_counts(chunk['HIGH_ALARM_TRIGGERED'],
                                             chunk['HIGH_ALARM_TRIGGERED_PREDICTED'])})

    # Confusion matrix for low threshold analysis
    low_records.append({'CHUNK_ID': chunk_id,
                        **_confusion_counts(chunk['LOW_ALARM_TRIGGERED'],
                                            chunk['LOW_ALARM_TRIGGERED_PREDICTED'])})

confusion_matrix_high = pd.DataFrame(high_records, columns=confusion_matrix_columns)
confusion_matrix_low = pd.DataFrame(low_records, columns=confusion_matrix_columns)

# Accuracy = (TP + TN) / (TP + FN + FP + TN), computed once per table rather than on
# every loop iteration. A chunk with zero rows yields NaN instead of raising ZeroDivisionError.
for confusion_matrix in (confusion_matrix_high, confusion_matrix_low):
    confusion_matrix['ACCURACY'] = ((confusion_matrix['TP'] + confusion_matrix['TN'])
                                    / confusion_matrix[['TP', 'FN', 'FP', 'TN']].sum(axis=1))

print(f'Confusion Matrix (HIGH): \n{confusion_matrix_high}\n')
print(f'Confusion Matrix (LOW): \n{confusion_matrix_low}')

## Correlation Between Chunk Length and Chunk Accuracy

In [None]:
# TODO: Maybe include chunk length into chunks dict (instead of creating 3 lists)

# Index the per-chunk accuracies by chunk ID once, so each lookup no longer scans the whole table
accuracy_high_by_chunk = confusion_matrix_high.set_index('CHUNK_ID')['ACCURACY']
accuracy_low_by_chunk = confusion_matrix_low.set_index('CHUNK_ID')['ACCURACY']

# Fill lists for plotting (all three lists follow the same chunk order)
chunk_ids = list(chunks)
chunk_lengths = [len(chunks[chunk_id]) for chunk_id in chunk_ids]
chunk_accuracies_high = accuracy_high_by_chunk.loc[chunk_ids].to_list()
chunk_accuracies_low = accuracy_low_by_chunk.loc[chunk_ids].to_list()

# Define background color, subplots and suptitle
sns.set_style('whitegrid')
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(8, 4))
fig.suptitle('Correlation of Chunk Length and Chunk Accuracy', fontsize=14)

# Left plot: high threshold analysis; right plot: low threshold analysis
palette = sns.color_palette('colorblind')
panels = [
    (ax1, chunk_accuracies_high, palette[0], 'Accuracy Regarding High Thresholds'),
    (ax2, chunk_accuracies_low, palette[1], 'Accuracy Regarding Low Thresholds'),
]
for ax, accuracies, color, title in panels:
    ax.plot(chunk_lengths, accuracies, 'o', color=color)
    ax.set_title(title, fontsize=10)
    ax.set_xlabel('Chunk Length', fontsize=8)
    ax.set_ylabel('Chunk Accuracy', fontsize=8)
    ax.set_ylim(bottom=0, top=1.1)

# Improve layout and save figure; create the output directory first so savefig cannot fail on it
fig.tight_layout()
plot_path = f'../../plots/darts/1000_chunks/correlation_chunk_length_and_accuracy_{model_type}_{parameter}.png'
os.makedirs(os.path.dirname(plot_path), exist_ok=True)
fig.savefig(plot_path, dpi=1200)