# Discharging patients from the Intensive Care Unit

## Description

At the Intensive Care Unit (ICU), patients in critical, life-threatening conditions receive continuous care. The purpose of an ICU admission is to bring the patient to a stable condition, after which they can return to the regular ward. The moment of discharge is important: discharging a patient too early can lead to complications and readmission to the ICU, with an increased risk of death. Keeping patients too long is also undesirable: capacity is limited, and patients can develop complications as a result of a prolonged ICU stay.

Use machine learning to help IC doctors decide when a patient can be discharged.

## Data

Enormous amounts of data are continuously collected at the IC. Physiological values such as heart rate and blood pressure are recorded per minute, and, in addition, patient characteristics, clinical observations and laboratory results such as blood values are known.

Here are the details of 3 (artificial) datasets with IC data:
* age.csv
    * file containing each patient's age
    * fields: *pat_id*, *age*
* admission.csv
    * file containing information on when a patient is admitted to and discharged from the IC
    * fields: *pat_id*, *date_admission*, *date_discharge*
* signal.csv
    * file containing (artificial) information on 3 physiological parameters on the patient
    * fields: *pat_id*, *day*, *hour*, *parameter*, *value*
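To make the three schemas concrete, here is a minimal sketch of what the files might look like. The rows are invented for illustration; the `;` separator is taken from the default of `load_data()` below, the `YYYY-MM-DD` date format and the integer `hour` are inferred from the time-conversion helpers, and the parameter names from the signal means computed later. None of this is the real data.

```python
import io

import pandas as pd

# Invented rows, for illustration only; column names follow the field lists above
AGE_CSV = "pat_id;age\n1;67\n2;54\n"
ADMISSION_CSV = (
    "pat_id;date_admission;date_discharge\n"
    "1;2019-03-01;2019-03-04\n"
    "1;2019-03-08;2019-03-12\n"
    "2;2019-05-10;2019-05-11\n"
)
SIGNAL_CSV = (
    "pat_id;day;hour;parameter;value\n"
    "1;2019-03-01;8;temperature;38.2\n"
    "1;2019-03-01;9;blood_pressure;92.5\n"
    "2;2019-05-10;14;respiration_rate;18.0\n"
)

# parse each in-memory "file" with the same ';' separator as load_data()
age = pd.read_csv(io.StringIO(AGE_CSV), sep=";")
admission = pd.read_csv(io.StringIO(ADMISSION_CSV), sep=";")
signal = pd.read_csv(io.StringIO(SIGNAL_CSV), sep=";")

print(admission.groupby("pat_id").size())  # patient 1 has two admissions
```

The long format of `signal.csv` (one row per patient, hour and parameter) is why the notebook later filters by `parameter` to obtain one sequence per signal type.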

## Assignment

Using the patient's signal data, build an algorithm that can help the doctors at the IC decide who can be discharged. The algorithm should predict which patients would have a high risk of being readmitted if they were discharged.

## Implementation

### Load and inspect data

In [None]:
import random

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from datetime import datetime
from collections import Counter

from sklearn import metrics
from sklearn.preprocessing import Normalizer
from sklearn.neural_network import MLPClassifier

%matplotlib inline

In [None]:
# function to load data from csv file
def load_data(input_file, sep=';', fields=None):
    """Load patient data from csv file"""
    return pd.read_csv(input_file, header=0, sep=sep, usecols=fields)

# functions for time handling
def convert_time_in(s):
    """From string to datetime (start of day)"""
    s += " 00:00:00"
    return datetime.strptime(s, "%Y-%m-%d %H:%M:%S")

def convert_time_out(s):
    """From string to datetime (end of day)"""
    s += " 23:59:59"
    return datetime.strptime(s, "%Y-%m-%d %H:%M:%S")

def convert_time_mid(s):
    """From string to datetime (given hour)"""
    return datetime.strptime(s, "%Y-%m-%d %H")

# set up file names
age_file = '/src/data/age.csv'
adm_file = '/src/data/admission.csv'
sig_file = '/src/data/signal.csv'

* Load and inspect the 'age' table

In [None]:
age_table = load_data(age_file, fields=['pat_id', 'age'])
# set age field as integer
age_table['age'] = pd.to_numeric(age_table['age'], downcast='integer')

age_table.head()

* Load and inspect the 'admission' table

In [None]:
admission_table = load_data(
    adm_file, 
    fields=['pat_id', 'date_admission', 'date_discharge'])

# convert admission date to datetime format (start of day; hour=00:00)
admission_table['date_admission'] = admission_table['date_admission'].map(convert_time_in)

# convert discharge date to datetime format (end of day; hour=23:59)
admission_table['date_discharge'] = admission_table['date_discharge'].map(convert_time_out)

admission_table.head()
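The row-by-row `map(convert_time_in)` / `map(convert_time_out)` conversion can also be done in a vectorised way. A minimal sketch on hypothetical rows, assuming the dates are stored as `YYYY-MM-DD` strings as the helpers above imply:

```python
import pandas as pd

# Hypothetical admission rows in the same string format as admission.csv
admission = pd.DataFrame({
    "pat_id": [1, 1],
    "date_admission": ["2019-03-01", "2019-03-08"],
    "date_discharge": ["2019-03-04", "2019-03-12"],
})

# start of day: parsing a bare date yields 00:00:00, like convert_time_in()
admission["date_admission"] = pd.to_datetime(
    admission["date_admission"], format="%Y-%m-%d")

# end of day: add 23:59:59 to the parsed date, like convert_time_out()
admission["date_discharge"] = (
    pd.to_datetime(admission["date_discharge"], format="%Y-%m-%d")
    + pd.Timedelta(hours=23, minutes=59, seconds=59)
)
print(admission)
```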

* Load and inspect the 'signal' table

In [None]:
signal_table = load_data(
    sig_file, 
    fields=['pat_id', 'day', 'hour', 'parameter', 'value'])

# combine day and hour columns
signal_table['date_recording'] = signal_table['day'] + " " + signal_table['hour'].map(str)
signal_table['date_recording'] = signal_table['date_recording'].map(convert_time_mid)

# remove day and hour columns
signal_table = signal_table.drop("day", axis=1)
signal_table = signal_table.drop("hour", axis=1)

signal_table.head()
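Similarly, the string concatenation of `day` and `hour` followed by `convert_time_mid` could be replaced by date arithmetic. This sketch assumes `hour` holds integers from 0 to 23 (if it were stored as text, it would need `astype(int)` first); the rows are hypothetical.

```python
import pandas as pd

# Hypothetical signal rows: 'day' as a date string, 'hour' as an integer 0-23
signal = pd.DataFrame({
    "pat_id": [1, 1],
    "day": ["2019-03-01", "2019-03-01"],
    "hour": [8, 23],
    "parameter": ["temperature", "temperature"],
    "value": [38.2, 37.9],
})

# parse the day, then shift it forward by the recorded number of hours
signal["date_recording"] = (
    pd.to_datetime(signal["day"], format="%Y-%m-%d")
    + pd.to_timedelta(signal["hour"], unit="h")
)

# drop the now redundant columns in a single call
signal = signal.drop(columns=["day", "hour"])
print(signal)
```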

* Check the amount of information in each table

In [None]:
print("Age table contains {} entries".format(len(age_table)))
print("Admission table contains {} entries".format(len(admission_table)))
print("Signal table contains {} entries".format(len(signal_table)))

* Check the patient ids

In [None]:
def compare_ids(a, b):
    """Compare the ids in two series"""
    
    m = 0
    s = set()
    for pid in a:
        if pid in b:
            m += 1
        else:
            s.add(pid)
    return m, sorted(list(s))

age_ids = age_table.pat_id.unique()
adm_ids = admission_table.pat_id.unique()
sig_ids = signal_table.pat_id.unique()

m1, s1 = compare_ids(age_ids, adm_ids)
_, s2 = compare_ids(adm_ids, age_ids)
print("Age and admission tables have {} common patients".format(m1))
print("\t- the {} patients without admissions are: {}".format(len(s1), s1))
print("\t- the {} admitted patients without known age are: {}".format(len(s2), s2))

m1, s1 = compare_ids(age_ids, sig_ids)
_, s2 = compare_ids(sig_ids, age_ids)
print("\nAge and signal tables have {} common patients".format(m1))
print("\t- the {} patients without recorded signals are: {}".format(len(s1), s1))
print("\t- the {} recorded patients without known age are: {}".format(len(s2), s2))

m1, s1 = compare_ids(adm_ids, sig_ids)
_, s2 = compare_ids(sig_ids, adm_ids)
print("\nAdmission and signal tables have {} common patients".format(m1))
print("\t- the {} admitted patients without recorded signals are: {}".format(len(s1), s1))
print("\t- the {} recorded patients without known admissions are: {}".format(len(s2), s2))

In [None]:
# keep only the entries that have common ids in all 3 tables
ids = set()
for pid in age_ids:
    if pid in adm_ids and pid in sig_ids:
        ids.add(pid)

age_table = age_table[age_table.pat_id.isin(ids)]
admission_table = admission_table[admission_table.pat_id.isin(ids)]
signal_table = signal_table[signal_table.pat_id.isin(ids)]

print("Age table contains {} entries".format(len(age_table)))
print("Admission table contains {} entries".format(len(admission_table)))
print("Signal table contains {} entries".format(len(signal_table)))
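The id comparison and the filtering on common ids can also be written with set operations, which avoid the Python-level loops. A minimal sketch on hypothetical id arrays standing in for `age_ids`, `adm_ids` and `sig_ids`:

```python
import numpy as np

# Hypothetical id arrays standing in for age_ids, adm_ids and sig_ids
age_ids = np.array([1, 2, 3, 4])
adm_ids = np.array([2, 3, 4, 5])
sig_ids = np.array([3, 4, 5, 6])

# sorted unique ids found in the first array but not in the second
age_without_adm = np.setdiff1d(age_ids, adm_ids)          # -> [1]
adm_without_age = np.setdiff1d(adm_ids, age_ids)          # -> [5]
n_common_age_adm = len(np.intersect1d(age_ids, adm_ids))  # -> 3

# ids present in all three tables
common_ids = set(age_ids.tolist()) & set(adm_ids.tolist()) & set(sig_ids.tolist())

print(n_common_age_adm, age_without_adm, adm_without_age, sorted(common_ids))
```

The resulting set can be passed directly to `Series.isin`, e.g. `age_table[age_table.pat_id.isin(common_ids)]`.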

* Check the number of re-admissions

In [None]:
def detect_re_admissions():
    """Detect re-admissions of patients in the admission table"""   
    
    duplicates = admission_table.groupby('pat_id').pat_id.count()
    duplicates = Counter(duplicates.tolist())
    return duplicates

counts = detect_re_admissions()
for n_adm, n_pat in counts.items():
    print("{:4} patient(s) have {} logged admission(s)".format(
        int(n_pat), n_adm))
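The same histogram of admissions per patient can be obtained with `groupby(...).size()` followed by `value_counts()`. A small sketch on a hypothetical admission table:

```python
import pandas as pd

# Hypothetical admissions: patients 2 and 4 once, patient 1 twice, patient 3 three times
admission = pd.DataFrame({"pat_id": [1, 1, 2, 3, 3, 3, 4]})

# number of admissions per patient
adm_per_patient = admission.groupby("pat_id").size()

# number of patients per admission count, ordered by admission count
patients_per_count = adm_per_patient.value_counts().sort_index()

for n_adm, n_pat in patients_per_count.items():
    print("{:4} patient(s) have {} logged admission(s)".format(int(n_pat), n_adm))
```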

* Inspect the signal recordings on each admission

In [None]:
def display_data(pid, age, admissions, signals):
    """Pretty print the patient's information"""
    
    print('Patient data')
    print("  * id:  {}".format(pid))
    print("  * age: {}".format(age))
    
    print("  * history of admissions")
    for date_in, date_out in admissions:
        print("    - from {} to {}".format(date_in, date_out))
            
    for i in range(len(signals)):
        print("  * signal recordings for admission #{}".format(i+1))
        for signal_type, recordings in signals[i].items():
            rounded = [round(x, 2) for x in recordings[-10:]]
            print("    - {}: (...) {}".format(signal_type, rounded)) 
    
def inspect_patient(pid):
    """Inspect available data on a given patient id"""
    
    # get the patient's age
    age = age_table[age_table.pat_id == pid].age.item()
    
    # get the patient's admission dates    
    admissions = []
    selection = admission_table[admission_table.pat_id == pid]
    for row in selection.itertuples():
        admissions.append((row.date_admission, row.date_discharge))

    # get the patient's signals for each admission
    signal_types = signal_table.parameter.unique()
    signals = []
    for date_in, date_out in admissions:
        match = {}
        for signal in signal_types:
            match[signal] = []
            selection = signal_table.query(
                "(pat_id == @pid) & "\
                "(parameter == @signal) & "\
                "(@date_in <= date_recording <= @date_out)")
            for row in selection.itertuples():
                match[signal].append(row.value)
        signals.append(match)
    
    return age, admissions, signals

pid = 470
age, admissions, signals = inspect_patient(pid)
display_data(pid, age, admissions, signals)

* Show the mean values of blood pressure, respiration rate and temperature

In [None]:
def get_average_signal(signal_type):
    """Get the average value of the recorded signal type"""
    selection = signal_table[signal_table.parameter == signal_type].value
    return np.mean(selection)
    
signal_means={
    'blood_pressure': get_average_signal('blood_pressure'),
    'respiration_rate': get_average_signal('respiration_rate'),
    'temperature': get_average_signal('temperature'),
}

print(signal_means)
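Instead of one call per signal type, a single `groupby` on `parameter` computes all the means at once. A minimal sketch on hypothetical values:

```python
import pandas as pd

# Hypothetical recordings in the long format of signal.csv
signal = pd.DataFrame({
    "parameter": ["temperature", "temperature", "blood_pressure", "respiration_rate"],
    "value": [38.0, 37.0, 90.0, 18.0],
})

# mean value per parameter, returned as a {parameter: mean} dictionary
signal_means = signal.groupby("parameter")["value"].mean().to_dict()
print(signal_means)  # {'blood_pressure': 90.0, 'respiration_rate': 18.0, 'temperature': 37.5}
```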

### Feature engineering

* pre-processing
    * on the *age* table
        * set age field as integer value
    * on the *admission* table
        * convert *date_admission* field to datetime format (set the missing time to 00:00:00 so that the whole day is covered)
        * convert *date_discharge* field to datetime format (set the missing time to 23:59:59 so that the whole day is covered)
        * add a new binary label *high_risk* that flags an admission followed by a quick readmission
    * on the *signal* table
        * merge the *day* and *hour* columns under a single column and convert it to a datetime format
    * on all tables
        * keep only the entries of the 1425 patients common to all three tables
        
* merge data sources and generate **features for classification**
    * target output variable *y*
        * the *high_risk* field denoting a patient's imminent re-admission to the ICU
        * for the moment, an admission is considered high risk if the patient is re-admitted to the ICU within one week of being discharged
        * there are 64 patients with more than one admission to the ICU in the dataset
        * only 29 of the re-admissions of these 64 patients satisfy the one-week condition
    * explanatory input variables *x* (20 features)
        * age
        * time spent in the ICU (number of days)
        * statistics computed on the recordings of blood pressure, respiration rate and temperature (on each patient's admission to the ICU)
            * difference between last value and first value
            * difference between maximum value and minimum value
            * difference between the maximum values of the second half and first half
            * difference between the minimum values of the second half and first half
            * difference between the mean values of the second half and first half
            * difference between the standard deviation values of the second half and first half
    * for an admission with fewer than two recordings of blood pressure, respiration rate or temperature, the six features of that signal are set to zero
    * complete list of features (normalized)
        * *age*
        * *period*
        * *blood_pressure_difference_last_first*
        * *blood_pressure_difference_max_min*
        * *blood_pressure_difference_max2_max1*
        * *blood_pressure_difference_min2_min1*
        * *blood_pressure_difference_mean2_mean1*
        * *blood_pressure_difference_std2_std1*
        * *respiration_rate_difference_last_first*
        * *respiration_rate_difference_max_min*
        * *respiration_rate_difference_max2_max1*
        * *respiration_rate_difference_min2_min1*
        * *respiration_rate_difference_mean2_mean1*
        * *respiration_rate_difference_std2_std1*
        * *temperature_difference_last_first*
        * *temperature_difference_max_min*
        * *temperature_difference_max2_max1*
        * *temperature_difference_min2_min1*
        * *temperature_difference_mean2_mean1*
        * *temperature_difference_std2_std1*
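The target definition above (re-admission within one week of discharge) can be sketched with a per-patient `shift`: each stay is compared with the same patient's next admission. This is a minimal illustration on hypothetical rows, assuming discharge dates have already been moved to 23:59:59 as in the pre-processing; it is not the loop used in `generate_dataset` below.

```python
import pandas as pd

N_DAYS = 7  # one-week window, as in the text

# Hypothetical admissions; discharge dates already shifted to 23:59:59
admission = pd.DataFrame({
    "pat_id": [1, 1, 2, 2, 3],
    "date_admission": pd.to_datetime(
        ["2019-03-01", "2019-03-08", "2019-05-01", "2019-06-20", "2019-07-01"]),
    "date_discharge": pd.to_datetime(
        ["2019-03-04 23:59:59", "2019-03-12 23:59:59", "2019-05-03 23:59:59",
         "2019-06-25 23:59:59", "2019-07-02 23:59:59"]),
})

admission = admission.sort_values(["pat_id", "date_admission"])

# admission date of the same patient's next stay (NaT for the last stay)
next_admission = admission.groupby("pat_id")["date_admission"].shift(-1)

# whole days between this discharge and the next admission (NaN when there is none)
gap_days = (next_admission - admission["date_discharge"]).dt.days

# high risk if the next admission follows within N_DAYS of discharge;
# comparisons with NaN are False, so last stays are labelled 0
admission["high_risk"] = (gap_days <= N_DAYS).astype(int)
print(admission[["pat_id", "date_admission", "high_risk"]])
```

Only patient 1's first stay is labelled high risk here: the next admission starts about three days after discharge, whereas patient 2 returns after more than six weeks.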

* create features dataset

In [None]:
# the features on the signal sequence
def extract_features(data):
    """
    Define the features that describe the signal sequence
        - difference between last value and first value
        - difference between maximum value and minimum value
        - difference between the maximum values of the 2nd half and 1st half
        - difference between the minimum values of the 2nd half and 1st half
        - difference between the mean values of the 2nd half and 1st half
        - difference between the standard deviation values of the 2nd half and 1st half
    """
    
    if len(data) >= 2:
        # split data into first and second halves
        middle = len(data) // 2 
        first_half = data[:middle]
        second_half = data[middle:]
        
        return [
            data[-1] - data[0],
            np.max(data) - np.min(data),
            np.max(second_half) - np.max(first_half),
            np.min(second_half) - np.min(first_half),
            np.mean(second_half) - np.mean(first_half),
            np.std(second_half) - np.std(first_half)
        ]
    
    # fewer than two recordings: treat the signal as a constant sequence
    # (e.g. at its average value) => zero-valued differences

    mean_signal = [0.0 for _ in range(6)]
    return mean_signal
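As a worked example, the six differences can be computed by hand on a toy, steadily rising temperature sequence of four recordings. The values are invented; the formulas are the same as in `extract_features`.

```python
import numpy as np

# A toy, steadily rising sequence of four recordings
data = [36.5, 37.0, 37.5, 38.0]
middle = len(data) // 2                    # 2
first_half, second_half = data[:middle], data[middle:]

features = [
    data[-1] - data[0],                          # last - first
    np.max(data) - np.min(data),                 # max - min
    np.max(second_half) - np.max(first_half),    # max2 - max1
    np.min(second_half) - np.min(first_half),    # min2 - min1
    np.mean(second_half) - np.mean(first_half),  # mean2 - mean1
    np.std(second_half) - np.std(first_half),    # std2 - std1
]
print([round(float(f), 2) for f in features])  # [1.5, 1.5, 1.0, 1.0, 1.0, 0.0]
```

With an odd number of recordings, `len(data) // 2` puts the extra recording in the second half; with exactly two recordings each half holds a single value, so the standard-deviation difference is always zero.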

In [None]:
# generate the input-target variables
def generate_dataset(n_days=7, n_feat=20):
    """Create a new dataset for the classification process"""
    
    # initialize the features list and the label list
    x = []
    y = []
        
    # extract the list of unique signal types
    signal_types = sorted(signal_table.parameter.unique())
        
    # for each patient
    for age_row in age_table.itertuples():
        # extract id and age
        pid = age_row.pat_id
        age = age_row.age

        # extract ICU admissions and sort them by admission date
        adm_selection = admission_table[admission_table.pat_id == pid]
        adm_selection = adm_selection.sort_values('date_admission')

        admissions = []
        for adm_row in adm_selection.itertuples():
            admissions.append((adm_row.date_admission, adm_row.date_discharge))

        # check for high risks
        for i in range(len(admissions)):  
            date_in = admissions[i][0]
            date_out = admissions[i][1]

            # the number of days spent in ICU
            period = (date_out - date_in).days

            # check the high risk of the current admission
            high_risk = 0
            if i < len(admissions) - 1:
                next_date_in = admissions[i + 1][0]
                # high risk if the next admission starts within n_days of this discharge
                if (next_date_in - date_out).days <= n_days:
                    high_risk = 1

            feat = [age, period]

            # get the patient's signals for current admission
            for signal in signal_types:
                sig_selection = signal_table.query(
                    "(pat_id == @pid) & "\
                    "(parameter == @signal) & "\
                    "(@date_in <= date_recording <= @date_out)")
                
                values = []
                for row in sig_selection.itertuples():
                    values.append(row.value)
                    
                feat.extend(extract_features(values))

            # add features and label to dataset
            if len(feat) == n_feat:
                x.append(feat)
                y.append(high_risk)

    return x, y
        
x, y = generate_dataset()
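`generate_dataset` runs one `query` per admission and per signal type, which gets slow as the tables grow. One possible alternative, sketched here on hypothetical rows, is to merge signals with admissions on `pat_id`, keep the recordings that fall inside each admission window, and group the values into ordered sequences:

```python
import pandas as pd

# Hypothetical admissions (discharge at 23:59:59) and recordings for one patient
admission = pd.DataFrame({
    "pat_id": [1, 1],
    "date_admission": pd.to_datetime(["2019-03-01", "2019-03-08"]),
    "date_discharge": pd.to_datetime(["2019-03-04 23:59:59", "2019-03-12 23:59:59"]),
})
signal = pd.DataFrame({
    "pat_id": [1, 1, 1, 1],
    "parameter": ["temperature"] * 4,
    "date_recording": pd.to_datetime(
        ["2019-03-01 08:00", "2019-03-02 08:00", "2019-03-09 08:00", "2019-03-20 08:00"]),
    "value": [38.5, 37.5, 38.0, 36.8],
})

# pair every recording with every admission of the same patient
merged = signal.merge(admission, on="pat_id")

# keep only recordings inside the admission window (bounds included)
in_window = merged["date_recording"].between(
    merged["date_admission"], merged["date_discharge"])
merged = merged[in_window].sort_values("date_recording")

# one time-ordered list of values per (patient, admission, parameter)
sequences = (
    merged.groupby(["pat_id", "date_admission", "parameter"])["value"]
    .apply(list)
)
print(sequences)
```

Each list can then be passed to `extract_features`. Note that admissions without any recording of a given parameter do not appear in `sequences`, so they would have to be re-added with zero-valued features to reproduce the behaviour of `generate_dataset`.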

In [None]:
# show info on generated data
n = len(x)
npoz = sum(y)
nneg = len(y) - npoz

print(
    "Generated a dataset of {} entries, "\
    "with {} positive entries "\
    "and {} negative entries".format(
        n, npoz, nneg))

# choose a random entry
rand_index = random.randint(0, n - 1)
sample_feat = x[rand_index]
sample_label = y[rand_index]
print("\nExample of entry:\n\t- y = {}\n\t- x = {}".format(
    sample_label, sample_feat))

* Split data into train (75%) and test (25%) subsets

In [None]:
# get the indexes of positive and negative samples
pos_index = [i for i in range(len(y)) if y[i]]
neg_index = [i for i in range(len(y)) if not y[i]]

train_split = 0.75
pos_index_75 = random.sample(pos_index, int(train_split * len(pos_index)))

# define the indexes for train data: 75% of data
train_index = pos_index_75.copy()
train_index.extend(random.sample(neg_index, int(train_split * len(neg_index))))

# handle the unbalanced training data set
# - over-sample the minority class
ndup = int(train_split * (nneg - npoz))
for _ in range(ndup):
    train_index.append(random.choice(pos_index_75))

# define the indexes for test data: remaining data
test_index = [i for i in pos_index if i not in train_index]
test_index.extend(i for i in neg_index if i not in train_index)

# final shuffle of positive-negative indexes
random.shuffle(train_index)
random.shuffle(test_index)

# get the samples for training and testing
train_x = [x[i] for i in train_index]
train_y = [y[i] for i in train_index]

test_x = [x[i] for i in test_index]
test_y = [y[i] for i in test_index]

print("Generated a train set of {} instances ({} positive)".format(
    len(train_x), sum(train_y)))
print("Generated a test set of {} instances ({} positive)".format(
    len(test_x), sum(test_y)))
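The manual split and over-sampling above can also be expressed with scikit-learn utilities: `train_test_split(..., stratify=y)` keeps the positive rate equal in both subsets, and `sklearn.utils.resample` draws positives with replacement. The sketch below uses random stand-in data; fixing `random_state` (which the cell above does not do for `random`) makes the split reproducible. As in the original, over-sampling is applied to the training set only, so duplicated positives never leak into the test set.

```python
import numpy as np
from sklearn.model_selection import train_test_split
from sklearn.utils import resample

rng = np.random.RandomState(0)
x = rng.normal(size=(100, 20))        # stand-in feature matrix (20 features)
y = np.array([1] * 10 + [0] * 90)     # imbalanced stand-in labels

# 75% / 25% split with the same positive rate in both subsets
train_x, test_x, train_y, test_y = train_test_split(
    x, y, train_size=0.75, stratify=y, random_state=0)

# over-sample the positive training rows (with replacement) up to the negative count
pos, neg = train_y == 1, train_y == 0
pos_x_up, pos_y_up = resample(
    train_x[pos], train_y[pos],
    replace=True, n_samples=int(neg.sum()), random_state=0)

train_x = np.vstack([train_x[neg], pos_x_up])
train_y = np.concatenate([train_y[neg], pos_y_up])
print(train_x.shape, train_y.mean(), len(test_y))
```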

* Normalize each sample (feature vector) to unit L2 norm

In [None]:
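# Normalizer rescales each sample (row) to unit L2 norm; it is stateless,
# so fit() only validates the input and learns no statistics from train data
transformer = Normalizer().fit(train_x)

# apply the normalizer on train data
train_x = transformer.transform(train_x)

# apply the normalizer on test data
test_x = transformer.transform(test_x)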

In [None]:
print("Example of normalized entry: {}".format(
    transformer.transform([sample_feat])))
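Because `Normalizer` works row by row, a large-valued feature such as *age* can dominate each normalized vector. The sketch below contrasts it with `StandardScaler`, a different technique shown only for comparison (not what this notebook uses), which scales each column using statistics learned in `fit()`. The two samples are invented `[age, period]` pairs.

```python
import numpy as np
from sklearn.preprocessing import Normalizer, StandardScaler

# two hypothetical samples: [age, period]
x = np.array([[60.0, 2.0],
              [30.0, 4.0]])

# Normalizer: each ROW is rescaled to unit L2 norm; age dominates both rows
row_scaled = Normalizer().fit_transform(x)

# StandardScaler: each COLUMN is rescaled to zero mean and unit variance
col_scaled = StandardScaler().fit_transform(x)

print(row_scaled.round(3))
print(col_scaled.round(3))
```

After `Normalizer`, the *age* component of both rows is above 0.99, while *period* is squeezed close to zero; after `StandardScaler`, both columns span the same range.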

### Binary classification problem solved with a Neural Network classifier

* Train a neural network (multi-layer perceptron) classifier

In [None]:
model = MLPClassifier(
    hidden_layer_sizes=(128, 128),
    learning_rate_init=0.001,
    alpha=0.01,
    max_iter=100,
    batch_size=252,
    verbose=1)

model.fit(train_x, train_y)

* Evaluate the classification performance
    * accuracy
    * confusion matrix
    * precision & recall
    * ROC curve

In [None]:
def plot_roc(fpr, tpr, roc_auc):
    """Draw ROC curve"""
    
    plt.plot(
        fpr,
        tpr,
        color='red',
        label="ROC curve (area = {:.2%})".format(roc_auc))
    plt.plot([0, 1], [0, 1], color='blue', linestyle='--')
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate')
    plt.ylabel('True Positive Rate')
    plt.title('Receiver operating characteristic')
    plt.legend(loc="lower right")
    plt.show()
    
def display_performance(real, predicted):
    """Display performance metrics"""
    
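    acc = metrics.accuracy_score(real, predicted)
    # labels=[1, 0] puts the positive (high-risk) class in the first row/column
    cm = metrics.confusion_matrix(real, predicted, labels=[1, 0])
    report = metrics.classification_report(real, predicted)

    # NOTE: the ROC curve and AUC are computed from hard 0/1 predictions,
    # so the curve has a single operating point between (0, 0) and (1, 1)
    fpr, tpr, thresholds = metrics.roc_curve(real, predicted)
    auc = metrics.roc_auc_score(real, predicted)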
    
    print("Accuracy:\n{:.2%}\n".format(acc))
    print("ConfusionMatrix:\n{}\n".format(cm))
    print("Report:\n{}\n".format(report))
    print("AUC:\n{:.2%}".format(auc))
    
    plot_roc(fpr, tpr, auc)
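An ROC curve is usually drawn from continuous scores rather than hard labels: with 0/1 predictions there is only one threshold, so the "curve" is two straight segments and the AUC reduces to the mean of sensitivity and specificity. `MLPClassifier` provides `predict_proba`, so the positive-class probability could be used instead. The sketch below illustrates the difference on synthetic data, with a `LogisticRegression` model as a fast stand-in for the neural network.

```python
from sklearn import metrics
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

# synthetic stand-in data: 300 training rows, 100 test rows
x, y = make_classification(n_samples=400, n_features=20, random_state=0)
model = LogisticRegression(max_iter=1000).fit(x[:300], y[:300])
x_test, y_test = x[300:], y[300:]

# hard 0/1 predictions: a single operating point
auc_labels = metrics.roc_auc_score(y_test, model.predict(x_test))

# positive-class probabilities: every distinct score is a candidate threshold
scores = model.predict_proba(x_test)[:, 1]
fpr, tpr, thresholds = metrics.roc_curve(y_test, scores)
auc_scores = metrics.roc_auc_score(y_test, scores)

print("AUC from labels: {:.2%}, from scores: {:.2%}".format(auc_labels, auc_scores))
```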

* Check for overfitting (performance on the training data)

In [None]:
# predict labels on TRAIN data
predicted = model.predict(train_x)
display_performance(train_y, predicted)

* Check generalization performance

In [None]:
# predict labels on TEST data
predicted = model.predict(test_x)
display_performance(test_y, predicted)

## Conclusions

* poor performance (an AUC of only 54%)
* expect better results with more data (especially with more positive examples)
* better domain knowledge could help investigate other features / approaches

## Perspectives

* try the approach on a bigger dataset
* try other features
* try feature selection
* try other solutions for handling missing values
* try other solutions for handling an imbalanced dataset
* try other classification algorithms
* tune hyper-parameters