<h1><center>CS598 Deep Learning for Healthcare Spring 2023<br>Paper Reproduction Project</center></h1>

<h3><center>Gilberto Ramirez and Jay Kakwani<br><span style="font-family:monospace;">{ger6, kakwani2}@illinois.edu</span><br><font color="lightgrey">Group ID: 27 | Paper ID: 181</font></center></h3>

In this project, we aim to reproduce the paper [*Learning Tasks for Multitask Learning: Heterogeneous Patient Populations in the ICU* (Suresh et al., 2018)](https://arxiv.org/abs/1806.02878). In this paper, the authors propose a novel two-step pipeline to predict in-hospital mortality across patient populations with different characteristics. The first step of the pipeline divides patients into relevant non-overlapping cohorts in an unsupervised way, using a long short-term memory (LSTM) autoencoder followed by a Gaussian Mixture Model (GMM). The second step predicts in-hospital mortality for each patient cohort identified in the first step, using an LSTM-based multi-task learning model in which every cohort is treated as a separate task.
The paper claims that, with this pipeline, the multi-task learning model can leverage knowledge shared across the distinct patient groups, and that it works effectively because the groups are obtained with a data-driven method rather than by relying on domain knowledge or auxiliary labels.

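A common way to build the multi-task model described in the second step is hard parameter sharing: layers shared by all cohorts, followed by one output head per cohort (task), with each patient scored by the head of its own cohort. The NumPy sketch below shows only that routing. It assumes a single dense shared layer in place of the paper's LSTM-based shared layers and a logistic head per cohort; `multitask_forward` and its parameters are hypothetical and not part of `mtl_patients`:

```python
import numpy as np

def multitask_forward(x, cohorts, W_shared, b_shared, heads):
    """Forward pass of a hard-parameter-sharing multi-task model (sketch).

    x        : (batch, features) inputs, e.g. a pooled patient representation
    cohorts  : (batch,) integer cohort (task) id per patient, in [0, len(heads))
    W_shared : (features, hidden) weights shared by all cohorts
    b_shared : (hidden,) shared bias
    heads    : one (w_k, b_k) pair per cohort; w_k has shape (hidden,), b_k is a scalar
    Returns the predicted mortality probability of each patient.
    """
    # Shared representation, learned from the patients of all cohorts (ReLU)
    h = np.maximum(0.0, x @ W_shared + b_shared)
    logits = np.empty(len(x))
    for k, (w_k, b_k) in enumerate(heads):
        mask = cohorts == k
        # Each patient is scored only by the output head of its own cohort
        logits[mask] = h[mask] @ w_k + b_k
    return 1.0 / (1.0 + np.exp(-logits))  # sigmoid

rng = np.random.default_rng(0)
num_features, hidden, num_tasks = 8, 4, 3
x = rng.normal(size=(6, num_features))
cohorts = np.array([0, 1, 2, 0, 1, 2])
W_shared = rng.normal(size=(num_features, hidden))
b_shared = np.zeros(hidden)
heads = [(rng.normal(size=hidden), 0.0) for _ in range(num_tasks)]
probs = multitask_forward(x, cohorts, W_shared, b_shared, heads)
```

During training, `W_shared` would receive gradients from every cohort, while each head only receives gradients from its own cohort's patients.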
## Table of Contents

1. [Data](#section-1)
2. [Methods & Results](#section-2)

## <a class="anchor" id="section-1">1. Data</a>

This paper uses the publicly available [MIMIC-III database](https://www.nature.com/articles/sdata201635), which contains clinical data collected in a critical care setting. After reviewing the paper in detail, we decided to use [MIMIC-Extract](https://arxiv.org/abs/1907.08322), an open-source pipeline by Wang et al. (2020) that transforms the raw EHR data into usable Pandas dataframes containing hourly time series of vitals and laboratory measurements, after performing unit conversion, outlier handling, and aggregation of semantically similar features.

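To illustrate what hourly aggregation means here, the following NumPy sketch buckets irregularly charted values by the hour since ICU admission and averages each bucket, leaving `NaN` for hours with no measurement. It is an illustration under that assumption only, not MIMIC-Extract's implementation, and it omits unit conversion, outlier handling, and feature grouping; `hourly_means` is a hypothetical helper:

```python
import numpy as np

def hourly_means(event_hours, values, num_hours):
    """Average irregular measurements into hourly buckets.

    event_hours : hours since ICU admission at which each value was charted
    values      : the measured values (same length as event_hours)
    num_hours   : length of the output time series
    Returns an array of length num_hours holding, for each hour h, the mean of
    the values charted in [h, h + 1), or NaN if nothing was charted in that hour.
    """
    event_hours = np.asarray(event_hours, dtype=float)
    values = np.asarray(values, dtype=float)
    buckets = np.floor(event_hours).astype(int)
    keep = (buckets >= 0) & (buckets < num_hours)  # drop events outside the window
    sums = np.bincount(buckets[keep], weights=values[keep], minlength=num_hours)
    counts = np.bincount(buckets[keep], minlength=num_hours)
    means = np.full(num_hours, np.nan)
    observed = counts > 0
    means[observed] = sums[observed] / counts[observed]
    return means

# Heart rate charted three times in hour 0, never in hour 1, once in hour 2
print(hourly_means([0.1, 0.5, 0.9, 2.3], [80, 90, 100, 70], num_hours=4))
# hourly means: 90, NaN, 70, NaN
```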
Unfortunately, the MIMIC-Extract pipeline does not provide two features that the [paper code](https://github.com/mit-caml/multitask-patients) uses:
* `timecmo_chart`, the timestamp at which a patient was placed in CMO (Comfort Measures Only) status. This feature comes from a MIMIC-III concept table called `code_status`.
* `sapsii`, the SAPS (Simplified Acute Physiology Score) II of the patient. This feature comes from another MIMIC-III concept table called `sapsii`.

As a result, there are three data files needed to run this notebook:
* `all_hourly_data.h5`, an HDF file resulting from running the MIMIC-Extract pipeline which is publicly available in GCP using [this link](https://console.cloud.google.com/storage/browser/mimic_extract) and referenced in the [MIMIC-Extract github repo](https://github.com/MLforHealth/MIMIC_Extract).
* `code_status.csv`, a CSV file holding the MIMIC concept table `CODE_STATUS` that can be generated following the instructions in [this link within the MIT-LCP github repo](https://github.com/MIT-LCP/mimic-code/tree/main/mimic-iii/concepts#generating-the-concepts-in-postgresql).
* `sapsii.csv`, a CSV file holding the MIMIC concept table `SAPSII` that can be generated following the instructions in [this link within the MIT-LCP github repo](https://github.com/MIT-LCP/mimic-code/tree/main/mimic-iii/concepts#generating-the-concepts-in-postgresql).

By default, the functions used in this notebook expect the three files to be in the folder `../data/`. A different location can be passed as an argument to the functions that process the data.

All code needed to replicate the paper is in [our github repo](https://github.com/ger6-illini/dl4h-sp23-team27-project) inside a Python module called `mtl_patients`.

The first function from that module we use is `get_summaries()`. It returns three summaries as three dataframes which, in return order, are:
* A summary providing some statistics of all patients broken by careunit.
* A summary providing some statistics of all patients broken by SAPS-II score quartile.
* A summary providing some statistics of the 29 distinct physiological measurements used in the paper.

In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import pandas as pd

import sys
pathname = "../code/"
if pathname not in sys.path:
    sys.path.append(pathname)

from mtl_patients import get_summaries

In [2]:
pat_summ_by_cu_df, pat_summ_by_sapsiiq_df, vitals_labs_summ_df = get_summaries()

Let's now display the summaries one at a time.

### 1.1. Data summary by patients in each intensive care unit (ICU)

In [3]:
pat_summ_by_cu_df

| Careunit | N | n | Class Imbalance | Age (Mean) | Gender (Male) |
|---|---:|---:|---:|---:|---:|
| CCU | 5193 | 790 | 0.152 | 83.31 | 0.58 |
| CSRU | 7050 | 223 | 0.032 | 69.54 | 0.67 |
| MICU | 12207 | 2674 | 0.219 | 78.21 | 0.51 |
| SICU | 5520 | 829 | 0.15 | 73.49 | 0.51 |
| TSICU | 4502 | 583 | 0.129 | 67.33 | 0.61 |
| Overall | 34472 | 5099 | 0.148 | 75.03 | 0.57 |


In the previous summary, patients are grouped by the careunit to which they were first admitted, one of the following five:
* CCU: Coronary Care Unit
* CSRU: Cardiac Surgery Recovery Unit
* MICU: Medical Intensive Care Unit
* SICU: Surgical Intensive Care Unit
* TSICU: Trauma Surgical Intensive Care Unit

An `Overall` row covering all patients is also included. The statistics provided by the summary are:
* `N`: The number of samples (patients) in the group.
* `n`: The number of samples (patients) meeting the in-hospital mortality criteria defined in the paper: the patient died, had a "Do Not Resuscitate" (DNR) note, or had a "Comfort Measures Only" (CMO) note.
* `Class Imbalance`: The proportion of patients meeting the in-hospital mortality criteria defined in the paper, i.e., $\dfrac{\text{n}}{\text{N}}$.
* `Age (Mean)`: Mean age of patients for each group in years.
* `Gender (Male)`: Ratio of patients that are males.

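The definitions of `N`, `n`, and `Class Imbalance` can be checked against the table: for CCU, 790 / 5193 ≈ 0.152, and overall, 5099 / 34472 ≈ 0.148. The sketch below computes the same three statistics per careunit from per-patient flags; the inputs `died`, `dnr`, and `cmo` are hypothetical boolean arrays, and `imbalance_summary` is a hypothetical helper, not the actual `mtl_patients` code:

```python
import numpy as np

def imbalance_summary(careunits, died, dnr, cmo):
    """Return {careunit: (N, n, n / N)} plus an 'Overall' row."""
    careunits = np.asarray(careunits)
    # Positive label: died in hospital OR had a DNR note OR had a CMO note
    positive = np.asarray(died, bool) | np.asarray(dnr, bool) | np.asarray(cmo, bool)

    def row(mask):
        N = int(mask.sum())            # patients in the group
        n = int(positive[mask].sum())  # patients meeting the mortality criteria
        return N, n, round(n / N, 3)   # class imbalance = n / N

    summary = {str(u): row(careunits == u) for u in np.unique(careunits)}
    summary["Overall"] = row(np.ones(len(careunits), dtype=bool))
    return summary

# Toy example with five patients in two careunits
print(imbalance_summary(
    careunits=["CCU", "CCU", "MICU", "MICU", "MICU"],
    died=[1, 0, 0, 0, 1],
    dnr=[0, 0, 1, 0, 0],
    cmo=[0, 0, 0, 0, 1],
))
# → {'CCU': (2, 1, 0.5), 'MICU': (3, 2, 0.667), 'Overall': (5, 3, 0.6)}
```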
This summary was prepared to match Table 1 in the original paper. The differences between the two can be attributed to differences between the preprocessing done by MIMIC-Extract and the preprocessing done by the authors in 2018, before MIMIC-Extract became available.

### 1.2. Data summary by patients in each SAPS-II score quartile

In [7]:
pat_summ_by_sapsiiq_df

| SAPS-II Quartile | N | n | Class Imbalance | Age (Mean) | Gender (Male) | SAPS-II (Min) | SAPS-II (Mean) | SAPS-II (Max) |
|---|---:|---:|---:|---:|---:|---:|---:|---:|
| 0 | 7449 | 115 | 0.015 | 45.5 | 0.61 | 0 | 16.56 | 22 |
| 1 | 10322 | 669 | 0.065 | 68.84 | 0.58 | 23 | 27.73 | 32 |
| 2 | 8360 | 1274 | 0.152 | 86.7 | 0.55 | 33 | 36.72 | 41 |
| 3 | 8341 | 3041 | 0.365 | 97.36 | 0.53 | 42 | 52.62 | 118 |
| Overall | 34472 | 5099 | 0.148 | 75.03 | 0.57 | 0 | 33.52 | 118 |


In the previous summary, patients are grouped by the quartile of their SAPS-II score. As can be seen, the four quartiles cover the score ranges $[0, 22]$, $[23, 32]$, $[33, 41]$, and $[42, 118]$. This breakdown is included in the authors' code but not in the paper; we suspect the class imbalance was the primary reason for it. The summary shows that most of the patients meeting the mortality criteria fall in quartile $3$ (3041 of 5099, a class imbalance of 0.365), which is expected since higher SAPS-II scores indicate more severe illness.

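The four quartile groups above differ in size (from 7449 to 10322 patients), which is consistent with computing quartile cut points on integer scores with many ties. A minimal NumPy sketch of quartile assignment, assuming the empirical 25th/50th/75th percentiles as cut points and right-closed intervals (a score equal to a cut point goes to the lower quartile, as with `pandas.qcut`); the authors' code may bin differently, and `sapsii_quartiles` is a hypothetical helper:

```python
import numpy as np

def sapsii_quartiles(scores):
    """Assign each score to a quartile 0-3 using the empirical quartiles as cut points.

    Intervals are right-closed: a score equal to a cut point goes to the lower
    quartile. With many tied integer scores, groups can be unequal or even empty.
    """
    scores = np.asarray(scores)
    cuts = np.percentile(scores, [25, 50, 75])  # default linear interpolation
    # side="left" returns i such that cuts[i-1] < score <= cuts[i]
    return np.searchsorted(cuts, scores, side="left")

toy_scores = [10, 20, 20, 20, 30, 40, 50, 60]
print(sapsii_quartiles(toy_scores))  # → [0 0 0 0 2 2 3 3]; cuts are 20, 25, 42.5
```

In this toy example, the ties at 20 leave quartile 1 empty.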
### 1.3. Data summary for physiological measurements

In [5]:
vitals_labs_summ_df

| Vital/Lab Measurement | min | avg | max | std | N | pres. |
|---|---:|---:|---:|---:|---:|---:|
| anion gap | 5.0 | 13.72 | 50.0 | 3.99 | 183732 | 0.0835 |
| bicarbonate | 0.0 | 24.23 | 53.0 | 4.74 | 192632 | 0.0875 |
| blood urea nitrogen | 0.0 | 26.21 | 250.0 | 21.75 | 194596 | 0.0884 |
| chloride | 50.0 | 105.22 | 175.0 | 6.31 | 211525 | 0.0961 |
| creatinine | 0.1 | 1.39 | 46.6 | 1.48 | 195429 | 0.0888 |
| diastolic blood pressure | 0.0 | 60.89 | 307.0 | 14.13 | 1908674 | 0.8672 |
| fraction inspired oxygen | 0.21 | 0.53 | 1.0 | 0.19 | 98315 | 0.0447 |
| glascow coma scale total | 3.0 | 12.49 | 15.0 | 3.59 | 377787 | 0.1716 |
| glucose | 33.0 | 140.49 | 1591.0 | 57.22 | 512585 | 0.2329 |
| heart rate | 0.0 | 84.97 | 300.0 | 17.27 | 1971748 | 0.8959 |


In the previous summary, the vitals and lab measurements selected in the paper (29 in total; the output above shows only the first 10) are listed with the following statistics:
* `min`: the minimum observed value of the measurement.
* `avg`: the average observed value of the measurement.
* `max`: the maximum observed value of the measurement.
* `std`: the standard deviation of the observed values of the measurement.
* `N`: the number of non-`NaN` samples of the measurement.
* `pres.`: the presence rate of the measurement, i.e., `N` divided by the total number of hours, across all patients, admissions, and ICU stays, in which at least one of the 104 vitals/labs measurements in the original MIMIC-Extract pipeline was taken.

All these statistics are computed from the `vitals_labs_mean` dataframe in the MIMIC-Extract pipeline, which holds the hourly average of each vital/lab measurement for each patient after ICU admission.

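The statistics in the table can be computed from an hourly matrix in which each row is one patient-hour, each column one measurement, and `NaN` marks hours without a measurement. A minimal NumPy sketch under that assumption: `measurement_summary` is a hypothetical helper, it uses the sample standard deviation (`ddof=1`, the pandas default), and to match the `pres.` definition above all 104 MIMIC-Extract measurement columns would need to be passed, since the denominator counts hours in which any column is observed:

```python
import numpy as np

def measurement_summary(hourly, names):
    """Summarize an hourly matrix (rows = patient-hours, columns = measurements).

    NaN marks hours in which a measurement was not taken. The denominator of
    `pres.` is the number of hours in which at least one column is observed.
    """
    hourly = np.asarray(hourly, dtype=float)
    observed = ~np.isnan(hourly)
    hours_with_any = observed.any(axis=1).sum()
    summary = {}
    for j, name in enumerate(names):
        col = hourly[observed[:, j], j]  # non-NaN values of this measurement
        summary[name] = {
            "min": col.min(), "avg": col.mean(), "max": col.max(),
            "std": col.std(ddof=1),      # sample standard deviation
            "N": int(col.size),
            "pres.": col.size / hours_with_any,
        }
    return summary

nan = np.nan
toy = np.array([[80.0, nan], [90.0, 1.0], [nan, nan], [nan, 3.0]])
stats = measurement_summary(toy, ["heart rate", "creatinine"])
print(stats["heart rate"]["N"], stats["heart rate"]["pres."])
```

In the toy matrix, three of the four hours have at least one measurement, so heart rate (observed in two hours) gets a presence rate of 2/3.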
## <a class="anchor" id="section-2">2. Methods & Results</a>

### 2.1. Identifying Meaningful Patient Cohorts

In [20]:
from mtl_patients import prepare_data

In [24]:
from mtl_patients import stratified_split

In [35]:
from keras.models import Model
from keras.layers import Input, LSTM, RepeatVector
from keras.optimizers import Adam
from keras.callbacks import EarlyStopping
from sklearn.mixture import GaussianMixture

In [36]:
train_val_random_seed = 0
embedding_dim = 50
epochs = 100
learning_rate = 0.0001
num_clusters = 3
gmm_tol = 0.0001

In [16]:
X, Y, cohort_careunits, cohort_sapsii_quartile, subject_ids = prepare_data(cutoff_hours=24, gap_hours=12)

In [27]:
X.shape

(32537, 24, 232)

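The `stratified_split` call in this section splits the samples while keeping the careunit mix similar across subsets, and the shapes shown here put about 70% of the samples in training (22775 of 32537). A minimal NumPy sketch of one way to do this, stratifying on (careunit, label) pairs; the validation/test fractions (`val_frac=0.1`, leaving 0.2 for test), the choice of strata, and the name `stratified_split_sketch` are assumptions, since the internals of `mtl_patients.stratified_split` are not shown here:

```python
import numpy as np

def stratified_split_sketch(labels, cohorts, train_frac=0.7, val_frac=0.1, seed=0):
    """Return sorted train/val/test index arrays, stratified on (cohort, label) pairs."""
    rng = np.random.default_rng(seed)
    labels = np.asarray(labels)
    cohorts = np.asarray(cohorts)
    train, val, test = [], [], []
    # Split every (cohort, label) stratum separately so that each subset keeps
    # roughly the same mix of careunits and of positive/negative patients
    for c in np.unique(cohorts):
        for y in np.unique(labels):
            idx = rng.permutation(np.flatnonzero((cohorts == c) & (labels == y)))
            n_train = int(round(train_frac * len(idx)))
            n_val = int(round(val_frac * len(idx)))
            train.append(idx[:n_train])
            val.append(idx[n_train:n_train + n_val])
            test.append(idx[n_train + n_val:])  # remainder goes to test
    return tuple(np.sort(np.concatenate(part)) for part in (train, val, test))

rng = np.random.default_rng(1)
toy_cohorts = rng.choice(["CCU", "MICU", "SICU"], size=1000)
toy_labels = (rng.random(1000) < 0.15).astype(int)
tr, va, te = stratified_split_sketch(toy_labels, toy_cohorts)
```

The returned indices can then be used to slice the data, e.g. `X[tr]`, `X[va]`, `X[te]`.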
In [26]:
# Do train/validation/test split using careunits as the cohort classifier
X_train, X_val, X_test, y_train, y_val, y_test, cohorts_train, cohorts_val, cohorts_test = \
    stratified_split(X, Y, cohort_careunits, train_val_random_seed=train_val_random_seed)

In [31]:
X_train.shape

(22775, 24, 232)

In [34]:
num_timesteps = X_train.shape[1]  # number of timesteps (T), e.g., 24 hours
num_features = X_train.shape[2]   # number of features (F), e.g., 232
# embedding_dim (hidden representation dimension) was set in the hyperparameter cell

# 1) take a temporal sequence of 1D vectors of `num_features` (F) elements
inputs = Input(shape=(num_timesteps, num_features))
# 2) encode it using an LSTM into a 1D vector with `embedding_dim` elements
#    (only the last hidden state is returned)
encoded = LSTM(embedding_dim)(inputs)
# 3) repeat the embedding from the encoder T times so it can be fed
#    to the decoder as a sequence, one copy per timestep
decoded = RepeatVector(num_timesteps)(encoded)
# 4) decode the result using an LSTM of size `num_features` to get the
#    reconstructed representation of the input
decoded = LSTM(num_features, return_sequences=True)(decoded)

# the LSTM autoencoder model takes the input, encodes it into an embedding,
# decodes it from the embedding, and outputs a reconstruction of the input
lstm_autoencoder = Model(inputs, decoded)

# the encoder model shares its layers with the LSTM autoencoder, so once the
# autoencoder is trained the encoder is trained too; it is used to get the embeddings
encoder = Model(inputs, encoded)

# `learning_rate` replaces the deprecated `lr` argument of Adam
lstm_autoencoder.compile(optimizer=Adam(learning_rate=learning_rate), loss='mse')
early_stopping = EarlyStopping(monitor='val_loss', patience=3)

# fit (train) the LSTM autoencoder model
print("Training LSTM autoencoder started...")
lstm_autoencoder.fit(X_train, X_train,
    epochs=epochs,
    batch_size=128,
    shuffle=True,
    callbacks=[early_stopping],
    validation_data=(X_val, X_val))
print("Training LSTM autoencoder trained!")



Training LSTM autoencoder started...
Epoch 1/100
...
Epoch 100/100
Training LSTM autoencoder trained!

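The autoencoder relies on two Keras layers whose behavior is easy to miss: `LSTM(embedding_dim)` returns only the last hidden state (its default is `return_sequences=False`), and `RepeatVector` copies that embedding once per timestep for the decoder. The NumPy sketch below reproduces that forward pass with the standard LSTM gate equations and random, untrained weights. It is an illustration only, not Keras's implementation, and `lstm_encode` and `repeat_vector` are hypothetical names:

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def lstm_encode(x, W, U, b):
    """Run a single-layer LSTM over x and return only the last hidden state.

    x : (batch, T, F) input sequences
    W : (F, 4 * d) input weights; U : (d, 4 * d) recurrent weights; b : (4 * d,)
    Gate order in the stacked weights: input, forget, candidate, output.
    Returns h_T of shape (batch, d), i.e. the embedding.
    """
    batch, T, _ = x.shape
    d = U.shape[0]
    h = np.zeros((batch, d))
    c = np.zeros((batch, d))
    for t in range(T):
        z = x[:, t, :] @ W + h @ U + b
        i = sigmoid(z[:, :d])            # input gate
        f = sigmoid(z[:, d:2 * d])       # forget gate
        g = np.tanh(z[:, 2 * d:3 * d])   # candidate cell state
        o = sigmoid(z[:, 3 * d:])        # output gate
        c = f * c + i * g
        h = o * np.tanh(c)
    return h

def repeat_vector(e, T):
    """Copy the embedding e of shape (batch, d) into T timesteps: (batch, T, d)."""
    return np.repeat(e[:, None, :], T, axis=1)

rng = np.random.default_rng(0)
batch, T, F, d = 2, 24, 232, 50
x = rng.normal(size=(batch, T, F))
W = rng.normal(scale=0.05, size=(F, 4 * d))
U = rng.normal(scale=0.05, size=(d, 4 * d))
b = np.zeros(4 * d)
embedding = lstm_encode(x, W, U, b)
decoder_input = repeat_vector(embedding, T)
print(embedding.shape, decoder_input.shape)  # → (2, 50) (2, 24, 50)
```

With `T = 24`, `F = 232`, and `d = 50` as in this notebook, the decoder LSTM then maps each of the 24 copies of the embedding back to 232 features.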

In [37]:
# now that the LSTM autoencoder model is trained
# the corresponding encoder is trained as well
# and we can use it to encode X
embeddings_X_train = encoder.predict(X_train)
embeddings_X = encoder.predict(X)



In [39]:
# With the embeddings now we can fit a Gaussian Mixture Model
print("Training Gaussian Mixture Model...")
gmm = GaussianMixture(n_components=num_clusters, tol=gmm_tol, verbose=True)
gmm.fit(embeddings_X_train)

# Finally, we can calculate the cluster membership
cohort_unsupervised = gmm.predict(embeddings_X)

Training Gaussian Mixture Model...
Initialization 0
  Iteration 10
  Iteration 20
  Iteration 30
  Iteration 40
Initialization converged: True

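`gmm.predict` assigns each embedding to the mixture component with the highest posterior probability. A NumPy sketch of that rule for the full-covariance case (scikit-learn's default `covariance_type='full'`), taking already-fitted parameters as input; fitting with EM is not shown, and `gmm_predict` is a hypothetical name:

```python
import numpy as np

def gmm_predict(points, weights, means, covariances):
    """Hard cluster assignment for a Gaussian mixture with full covariances.

    points : (n, d); weights : (K,); means : (K, d); covariances : (K, d, d)
    Returns argmax_k of log(weight_k) + log N(x; mean_k, cov_k), which is also
    the argmax of the posterior p(k | x).
    """
    n, d = points.shape
    log_prob = np.empty((n, len(weights)))
    for k, (w, mu, cov) in enumerate(zip(weights, means, covariances)):
        diff = points - mu
        # Squared Mahalanobis distance via a linear solve instead of an explicit inverse
        maha = np.einsum("ij,ij->i", diff, np.linalg.solve(cov, diff.T).T)
        _, logdet = np.linalg.slogdet(cov)
        log_prob[:, k] = np.log(w) - 0.5 * (d * np.log(2 * np.pi) + logdet + maha)
    return log_prob.argmax(axis=1)

weights = np.array([0.5, 0.5])
means = np.array([[0.0, 0.0], [5.0, 5.0]])
covs = np.stack([np.eye(2), np.eye(2)])
points = np.array([[0.1, -0.2], [4.8, 5.3], [1.0, 0.5]])
print(gmm_predict(points, weights, means, covs))  # → [0 1 0]
```

To cross-check against the model fitted above, one could pass `gmm.weights_`, `gmm.means_`, and `gmm.covariances_`; the assignments should then match `gmm.predict(embeddings_X)` up to numerical ties.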

In [41]:
np.save('../data/unsupervised_clusters.npy', cohort_unsupervised)