In [0]:
import torch
from torch.utils import data

import torch.nn as nn
import torch.nn.functional as F

import pandas as pd
%matplotlib inline
from matplotlib import pyplot as plt

import numpy as np
import pickle

from google.colab import auth

# Run on the GPU when one is available; input tensors are moved here with .to(device).
device = "cuda" if torch.cuda.is_available() else "cpu"

In [0]:
# Mirror PIC 1.0.0 from PhysioNet into ./physionet.org/files/picdb/1.0.0/ (read by the next cell).
# Credentialed download: wget prompts for the account password (--ask-password).
# NOTE(review): the PhysioNet username is hard-coded; substitute your own account.
!wget -r -N -c -np --user kyleliu --ask-password https://physionet.org/files/picdb/1.0.0/

In [0]:
# Read Data into DF
# Every table lives under the same PhysioNet mirror directory (downloaded by wget above),
# so keep the location in one place instead of repeating it per table.
PIC_DIR = 'physionet.org/files/picdb/1.0.0'

def read_pic_table(name):
  """Load the gzipped CSV for PIC table `name` (e.g. 'ADMISSIONS') from PIC_DIR."""
  return pd.read_csv(f'{PIC_DIR}/{name}.csv.gz', compression='gzip')

admissions = read_pic_table('ADMISSIONS')
chartevents = read_pic_table('CHARTEVENTS')
diagnoses_icd = read_pic_table('DIAGNOSES_ICD')
d_icd_diagnoses = read_pic_table('D_ICD_DIAGNOSES')
d_items = read_pic_table('D_ITEMS')
d_labitems = read_pic_table('D_LABITEMS')
emr_symptoms = read_pic_table('EMR_SYMPTOMS')
icu_stays = read_pic_table('ICUSTAYS')
input_events = read_pic_table('INPUTEVENTS')
lab_events = read_pic_table('LABEVENTS')
patients = read_pic_table('PATIENTS')
prescriptions = read_pic_table('PRESCRIPTIONS')
surgery_vital_signs = read_pic_table('SURGERY_VITAL_SIGNS')

In [0]:
# Easier to use: code -> description lookup tables.

# ITEMID -> LABEL for chart/surgery items (d_items) and for lab items (d_labitems).
item_dict = dict(zip(d_items.ITEMID, d_items.LABEL))
lab_item_dict = dict(zip(d_labitems.ITEMID, d_labitems.LABEL))

# ICD10_CODE_CN -> ICD10_CODE.
ICD_CN_TO_ICD = dict(zip(d_icd_diagnoses.ICD10_CODE_CN, d_icd_diagnoses.ICD10_CODE))


Here we include only the first admission of each patient.

In [0]:
# Clean: Include only the first admission

admissions = admissions.sort_values(by = ['ADMITTIME'])
chartevents = chartevents.sort_values(by = ['CHARTTIME'])
lab_events = lab_events.sort_values(by = ['CHARTTIME'])

# admissions is now in ADMITTIME order, so the first HADM_ID met for each
# SUBJECT_ID is that patient's earliest admission.
seen_patients = set()
admits_to_keep = []
for subject_id, hadm_id in zip(admissions.SUBJECT_ID, admissions.HADM_ID):
  if subject_id in seen_patients:
    continue
  seen_patients.add(subject_id)
  admits_to_keep.append(hadm_id)

In [0]:
def remove_admits(df, keep=None): 
  """Return a copy of `df` restricted to rows whose HADM_ID is in `keep`.

  keep: admission ids to retain; defaults to the global admits_to_keep (each
  patient's first admission). The result is an independent copy, so columns added
  to it later (e.g. HOURS_IN) cannot trip pandas' chained-assignment
  (SettingWithCopy) checks, even while the unfiltered frame is still referenced.
  """
  if keep is None:
    keep = admits_to_keep
  return df[df['HADM_ID'].isin(keep)].copy()

# Restrict every admission-level table to the kept (first) admissions.
(admissions, chartevents, diagnoses_icd, emr_symptoms, icu_stays,
 input_events, lab_events, prescriptions, surgery_vital_signs) = [
    remove_admits(table) for table in (
        admissions, chartevents, diagnoses_icd, emr_symptoms, icu_stays,
        input_events, lab_events, prescriptions, surgery_vital_signs)
]


Helper functions to parse timestamps, plus per-patient birth dates, admission times, and ages at admission.

In [0]:
from datetime import date, timedelta, time, datetime

def to_datetime(x):
  """Parse a 'YYYY-MM-DD HH:MM:SS' string into a datetime.

  Only the first two whitespace-separated tokens are read, and only the first
  three '-' / ':' fields of each.
  """
  tokens = x.split()
  ymd = tokens[0].split("-")
  hms = tokens[1].split(":")
  fields = [int(ymd[k]) for k in range(3)] + [int(hms[k]) for k in range(3)]
  return datetime(*fields)

# Per-subject reference times. admissions holds one (first) admission per subject here.
birth_date = {subject: to_datetime(dob) for subject, dob in zip(patients.SUBJECT_ID, patients.DOB)}

admit_date = {}
age_at_admission = {}
for subject, admit_str in zip(admissions.SUBJECT_ID, admissions.ADMITTIME):
  admit_date[subject] = to_datetime(admit_str)
  # timedelta from birth to admission.
  age_at_admission[subject] = admit_date[subject] - birth_date[subject]

In [0]:
# Time since admission (hours)
def normalize_time(patient_id, x):
  """Hours from patient_id's admission to timestamp string x (negative if earlier)."""
  elapsed = to_datetime(x) - admit_date[patient_id]
  hours = elapsed.total_seconds() / 3600.0
  return hours

In [0]:
# Every SUBJECT_ID in PATIENTS.
patient_set = set(patients.SUBJECT_ID)

In [0]:
# HOURS_IN = hours since admission for every timestamped event (negative = before admission).
for events, time_col in ((chartevents, 'CHARTTIME'),
                         (lab_events, 'CHARTTIME'),
                         (surgery_vital_signs, 'MONITORTIME')):
  events['HOURS_IN'] = events.apply(lambda row: normalize_time(row.SUBJECT_ID, row[time_col]), axis=1)

In [0]:
def _time_feature_name(idx):
  """Label of column `idx` in the per-hour layout [lab_feats | chart_feats | surgery_feats]; None past the end."""
  n_lab, n_chart, n_surgery = len(lab_feats), len(chart_feats), len(surgery_feats)
  if idx < n_lab:
    return lab_item_dict[lab_feats[idx]]
  if idx < n_lab + n_chart:
    return item_dict[chart_feats[idx - n_lab]]
  if idx < n_lab + n_chart + n_surgery:
    return item_dict[surgery_feats[idx - n_lab - n_chart]]
  return None

def get_feature_name(idx): 
  """Name of column `idx` of a per-hour covariate row.

  Layout: the hourly lab/chart/surgery features, then two 'gender' one-hot
  columns, then 'age' for everything beyond.
  """
  n_time = len(lab_feats) + len(chart_feats) + len(surgery_feats)
  if idx < n_time:
    return _time_feature_name(idx)
  return 'gender' if idx < n_time + 2 else 'age'

def get_feature_name_flattened(idx): 
  """Name of column `idx` of the flattened (LR/RF) covariates.

  Layout: WINDOW_SIZE consecutive blocks of hourly features, then the static
  block (two 'gender' columns, then 'age').
  """
  n_time = len(lab_feats) + len(chart_feats) + len(surgery_feats)
  hours_in = idx // n_time
  offset = idx - hours_in * n_time
  if hours_in == WINDOW_SIZE:
    return 'gender' if offset < 2 else 'age'
  return _time_feature_name(offset)

  

In [0]:
import math 
## Feature Set

## Chart Features: ITEMIDs 1001-1016 except 1005, kept as strings.
# NOTE(review): chart_feats are strings while lab_feats are ints; isin() only matches
# if chartevents' ITEMID column holds strings. Confirm against the loaded dtype.
chart_feats = [str(item_id) for item_id in range(1001, 1017) if item_id != 1005]

# Surgery Vital Signs: every ITEMID present, most frequent first.
surgery_feats = surgery_vital_signs['ITEMID'].value_counts().index.tolist()

# Lab Features (LABEVENTS ITEMIDs).
lab_feats = [
    5225, 5097, 5141, 5129, 5257, 5114, 5113, 5115, 5132, 5136,
    5226, 5230, 5218, 5224, 5212, 5033, 5041, 5223, 5215, 5174,
    5111, 6317, 5094, 5492, 5002, 5075, 5237, 5249, 5235, 5239,
    5227, 5026, 5031, 5024, 6085,
]



We use these dictionaries to index into the arrays that follow: `chart_X[patient_index_of[subject_id]]` is what you want, not `chart_X[subject_id]`. The same applies to item IDs (e.g. `lab_index_of[item_id]` gives the feature axis position).

In [0]:
# More Helper Dicts: position of each feature / patient along the array axes.
chart_index_of = {item: pos for pos, item in enumerate(chart_feats)}
lab_index_of = {item: pos for pos, item in enumerate(lab_feats)}
surgery_index_of = {item: pos for pos, item in enumerate(surgery_feats)}

print(chart_index_of)
print(lab_index_of)
print(surgery_index_of)

# Row of each SUBJECT_ID in the per-patient arrays, in patient_set's iteration order.
patient_index_of = {subject: row for row, subject in enumerate(patient_set)}
  
  

In [0]:
GAP_TIME          = 6  # In hours
WINDOW_SIZE       = 24 # Data collection window: In hours
# Label has to be first satisfied after GAP_TIME + WINDOW_SIZE

# Per-hour aggregates within the window: *_X accumulates measurement sums and
# *_Xcnt measurement counts, both shaped (patients, WINDOW_SIZE, features).
def _empty_hourly(n_feats):
  return np.zeros((len(patient_set), WINDOW_SIZE, n_feats))

chart_X, chart_Xcnt = _empty_hourly(len(chart_feats)), _empty_hourly(len(chart_feats))
lab_X, lab_Xcnt = _empty_hourly(len(lab_feats)), _empty_hourly(len(lab_feats))
surgery_X, surgery_Xcnt = _empty_hourly(len(surgery_feats)), _empty_hourly(len(surgery_feats))

# Subjects excluded from every task (events before admission, or too-early discharge).
subjects_to_remove = set()

def is_number(s):
    """Return True if `s` converts to a finite float.

    NaN and +/-inf are rejected: float() accepts them, but summing them would
    poison the hourly aggregates. Values float() cannot handle at all, whether
    unparsable strings (ValueError) or None (TypeError), also return False
    rather than raising.
    """
    try:
        return math.isfinite(float(s))
    except (TypeError, ValueError):
        return False

import math

def _accumulate_hourly(events, feats, index_of, sums, counts, value_col):
  """Add every in-window numeric measurement of `feats` from `events` into sums/counts.

  events needs SUBJECT_ID, ITEMID, HOURS_IN, VALUE and `value_col` columns.
  sums/counts are (patients, WINDOW_SIZE, len(feats)) arrays, updated in place.
  Subjects with any selected event before admission (HOURS_IN < 0) are added to
  the global subjects_to_remove instead.
  """
  # One combined mask: chaining df[m1][m2] makes pandas reindex m2 (with a warning).
  selected = events[(events['HOURS_IN'] < WINDOW_SIZE) & events['ITEMID'].isin(feats)]
  for _, row in selected.iterrows():
    if row.HOURS_IN < 0:
      subjects_to_remove.add(row.SUBJECT_ID)
    elif is_number(row.VALUE) and is_number(row[value_col]):
      value = float(row[value_col])
      # A NaN/inf value would poison this hour's sum and, through it, the global mean.
      if math.isfinite(value):
        cell = (patient_index_of[row.SUBJECT_ID], int(row.HOURS_IN), index_of[row.ITEMID])
        sums[cell] += value
        counts[cell] += 1

_accumulate_hourly(lab_events, lab_feats, lab_index_of, lab_X, lab_Xcnt, 'VALUENUM')
# Surgery vital signs accumulate VALUE itself; float() accepts it whether pandas read
# the column as numbers or as text (adding a raw string to the array would fail).
_accumulate_hourly(surgery_vital_signs, surgery_feats, surgery_index_of, surgery_X, surgery_Xcnt, 'VALUE')
_accumulate_hourly(chartevents, chart_feats, chart_index_of, chart_X, chart_Xcnt, 'VALUENUM')

# Drop patients discharged before the label horizon (window + gap).
for _, row in admissions.iterrows(): 
  if normalize_time(row.SUBJECT_ID, row.DISCHTIME) < (GAP_TIME + WINDOW_SIZE): 
    subjects_to_remove.add(row.SUBJECT_ID)

Here we implement simple forward/backward imputation. If time permits, we can try the other imputation methods described in https://www.nature.com/articles/s41598-018-24271-9.

`global_mean` is the mean of each feature over all time points and all patients. If a patient has no occurrences of a feature at any time point, that feature is filled with the global mean. Otherwise, observed values are propagated forward, then backward, to fill the missing hours.

In [0]:
# Missing Data Imputation

# Forward/Backward Imputation

# Compute global means first: per-feature mean over every observed measurement
# (all patients, all hours of the window).

def _observed_mean(sums, counts):
  """Per-feature mean of a (patients, hours, feats) sum/count pair; 0.0 where never observed."""
  total = sums.sum(axis=(0, 1))
  n = counts.sum(axis=(0, 1))
  # A feature with no in-window measurements would give 0/0 = NaN, which imputation
  # would then copy into every patient's series; fall back to 0.0 instead.
  return np.divide(total, n, out=np.zeros_like(total), where=n > 0)

global_chart_num = chart_Xcnt.sum(axis=(0, 1))
global_lab_num = lab_Xcnt.sum(axis=(0, 1))
global_surgery_num = surgery_Xcnt.sum(axis=(0, 1))

global_chart_mean = _observed_mean(chart_X, chart_Xcnt)
global_lab_mean = _observed_mean(lab_X, lab_Xcnt)
global_surgery_mean = _observed_mean(surgery_X, surgery_Xcnt)


def forward_backward_impute(feats, global_mean, observed=None): 
  """Fill missing entries of one patient's (time, feature) matrix.

  For each feature, values are first carried forward in time, then carried
  backward to fill leading gaps; a feature that is missing everywhere gets
  global_mean[j].

  INPUTS:
    feats       -- (num_steps, num_feats) array; not modified.
    global_mean -- (num_feats) fallback per feature.
    observed    -- optional boolean (num_steps, num_feats) mask of real
                   measurements. When omitted, entries <= 0 are treated as
                   missing (so legitimately non-positive values are also
                   overwritten); pass a mask to keep them.
  OUTPUTS:
    ret -- (num_steps, num_feats) imputed copy.
  """
  ret = np.array(feats, copy=True)
  missing = (ret <= 0) if observed is None else ~np.asarray(observed, dtype=bool)
  n_steps = ret.shape[0]
  for j in range(ret.shape[1]):
    # Forward fill: a filled entry inherits its predecessor's missing-ness.
    for i in range(1, n_steps):
      if missing[i, j]:
        ret[i, j] = ret[i - 1, j]
        missing[i, j] = missing[i - 1, j]
    # Backward fill: covers gaps before the first observation.
    for i in range(n_steps - 2, -1, -1):
      if missing[i, j]:
        ret[i, j] = ret[i + 1, j]
        missing[i, j] = missing[i + 1, j]
    # Still missing => never observed for this patient.
    for i in range(n_steps):
      if missing[i, j]:
        ret[i, j] = global_mean[j]
  return ret










In [0]:
# Set up X, Y 


# Set up labels

# From here on patient_set is a list. Its order matches patient_index_of, which was
# built by iterating the same (unchanged) set.
patient_set = list(patient_set)
# O(1) membership checks; `x in patient_set` on the list is a linear scan per row.
known_patients = set(patient_set)

mort_icu = dict() 
gender_one_hot = np.zeros((len(patient_set), 2))
age_vec = np.zeros((len(patient_set), 1))
for _, row in patients.iterrows(): 
  if row.SUBJECT_ID not in known_patients:
    continue
  pidx = patient_index_of[row.SUBJECT_ID]
  mort_icu[row.SUBJECT_ID] = row.EXPIRE_FLAG
  # Age at admission, in hours.
  age_vec[pidx][0] = (age_at_admission[row.SUBJECT_ID].total_seconds() / 3600.0)
  # Column 0 = 'M'; column 1 = any other GENDER value.
  if row.GENDER == 'M': 
    gender_one_hot[pidx][0] = 1
  else: 
    gender_one_hot[pidx][1] = 1

static_vec = np.concatenate((gender_one_hot, age_vec), axis = 1)
# [num_patients, 3]

# Hourly mean per feature; hours without measurements stay 0 (divide by 1, not 0).
chart_vec = chart_X / (chart_Xcnt + (chart_Xcnt == 0))
lab_vec = lab_X / (lab_Xcnt + (lab_Xcnt == 0))
surgery_vec = surgery_X / (surgery_Xcnt + (surgery_Xcnt == 0))

for i in range(len(patient_set)): 
  chart_vec[i] = forward_backward_impute(chart_vec[i], global_chart_mean)
  lab_vec[i] = forward_backward_impute(lab_vec[i], global_lab_mean)
  surgery_vec[i] = forward_backward_impute(surgery_vec[i],  global_surgery_mean)

time_vec = np.concatenate((lab_vec, chart_vec, surgery_vec), axis=2)
# time_vec [num_patients, window_size, num_lab_features + num_chart_features + num_vital_features]

# concatenate this with static_vec [num_patients, 3]
# time_vec [num_patients, window_size, num_lab_features + num_chart_features + num_vital_features]

# concatenate this with static_vec [num_patients, 3]

In [0]:
def get_mask(removed_subjects): 
  """Boolean list aligned with patient_set: False for each subject in removed_subjects, else True."""
  keep = [True] * len(patient_set)
  for subject in removed_subjects:
    keep[patient_index_of[subject]] = False
  return keep

def setup_data(task, model):
  """Assemble covariates and labels for `task`, dropping excluded subjects.

  task:  'Sepsis Prediction' or 'Mortality Prediction'; anything else raises ValueError.
  model: 'LR' / 'RF' -> NumPy covariates (n, WINDOW_SIZE * F + 3) (hourly features
         flattened, then static_vec) and labels (n,).
         otherwise     -> float tensors: covariates (n, WINDOW_SIZE, F + 3) with
         static_vec repeated at every hour, and labels (n, WINDOW_SIZE) with the
         patient's label repeated at every hour.
  F is time_vec.shape[2]; n is the number of subjects kept by get_mask.
  """
  # Copy: adding task-specific exclusions to the shared global set would silently
  # shrink the cohort of every task run afterwards.
  my_subjects_to_remove = set(subjects_to_remove)

  if task == 'Sepsis Prediction':
    # NOTE(review): SEPTIC is not defined in this part of the notebook; presumably it
    # maps SUBJECT_ID -> hours from admission to sepsis onset. Confirm.
    # Protect against labels found in the range [0, WINDOW_SIZE + GAP]. Remove these patients.
    for _, row in admissions.iterrows():
      if row.SUBJECT_ID in SEPTIC and SEPTIC[row.SUBJECT_ID] < (GAP_TIME + WINDOW_SIZE):
        my_subjects_to_remove.add(row.SUBJECT_ID)
    labels = np.array([1.0 if p in SEPTIC else 0.0 for p in patient_set])
  elif task == 'Mortality Prediction':
    # No need to protect against labels in the range [0, Window + Gap] (all labels are found at discharge time).
    labels = np.array([1.0 if mort_icu[p] == 1 else 0.0 for p in patient_set])
  else:
    raise ValueError('Unknown task: %r' % (task,))

  mask = get_mask(my_subjects_to_remove)

  if model in ['LR', 'RF']:  # flat-feature models
    n_patients, n_hours, n_feats = time_vec.shape
    flat_time = np.reshape(time_vec, (n_patients, n_hours * n_feats))
    covars = np.concatenate((flat_time, static_vec), axis = 1)
    return covars[mask, ...], labels[mask, ...]

  # Time-series models. float32 labels for both tasks so they match the float32 logits.
  labels_ts = torch.from_numpy(labels[mask, ...]).float()
  labels_ts = labels_ts.unsqueeze(1).expand((labels_ts.shape[0], time_vec.shape[1]))

  time_ts = torch.from_numpy(time_vec).float()
  static_ts = torch.from_numpy(static_vec).float().unsqueeze(1).expand((time_ts.shape[0], time_ts.shape[1], static_vec.shape[1]))
  covars_ts = torch.cat((time_ts, static_ts), dim=2)
  return covars_ts[mask, ...], labels_ts

In [0]:
from sklearn.metrics import roc_curve
def plot_roc(title, labels, probs): 
  """Plot one ROC curve, labelled `title`, against the chance diagonal."""
  fpr, tpr, _ = roc_curve(labels, probs)
  fig, ax = plt.subplots()
  ax.plot(fpr, tpr, label=title)
  ax.plot([0, 1], [0, 1], 'r--')
  ax.set_xlim([0.0, 1.0])
  ax.set_ylim([0.0, 1.05])
  ax.set_xlabel('1 - Specificity')
  ax.set_ylabel('Sensitivity')
  ax.set_title('ROC')
  ax.legend(loc="lower right")
  plt.show()

Random forest (RF) helper functions.

In [0]:

# X [num_patients, WINDOW_SIZE, num_features] (for LSTM)
# X [num_patients, WINDOW_SIZE * num_features] (for RF) 
# Y [num_patients, WINDOW_SIZE] (for LSTM) 
# Y [num_patients] (for RF)

from sklearn.model_selection import KFold
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

def train_rf(task, imbalanced=False, n_estimators=100, bootstrap=True, max_features='sqrt', random_state=None): 
  """Fit a RandomForestClassifier for `task` on an 80/20 split (random_state=1).

  Returns (model, x_test, y_test, x_test_orig). x_test is standardised with the
  scaler fitted on the training split; x_test_orig is the unscaled copy (its last
  column is age, used for cohorts).
  imbalanced: rebalance the training split with `oversample` (not defined in this
      part of the notebook; presumably an earlier cell).
  random_state: seed for the forest. The default None keeps the previous
      non-deterministic behaviour; pass an int for reproducible results.
  """
  X, Y = setup_data(task = task, model = 'RF')

  x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=1)
  x_test_orig = x_test

  # RF covariates are already 2-D (patients, features), so scale them directly.
  my_scaler = StandardScaler()
  x_train = my_scaler.fit_transform(x_train)
  x_test = my_scaler.transform(x_test)

  if imbalanced: 
    x_train, y_train = oversample(x_train, y_train)

  model = RandomForestClassifier(n_estimators=n_estimators, 
                                 bootstrap = bootstrap,
                                 max_features = max_features,
                                 random_state = random_state)
  model.fit(x_train, y_train)

  return model, x_test, y_test, x_test_orig

def evaluate_rf(model, x_test, y_test, mask): 
  """Score a fitted classifier on the rows of (x_test, y_test) selected by `mask`.

  Returns (auc, acc, predictions, positive-class probabilities). auc is NaN when the
  selection holds only one class (ROC-AUC is undefined there and roc_auc_score
  raises). auc and acc are both NaN for an empty selection, which sklearn refuses
  to predict on.
  """
  x_test = x_test[mask, ...]
  y_test = y_test[mask, ...]

  if len(y_test) == 0:
    return np.nan, np.nan, np.array([]), np.array([])

  rf_predictions = model.predict(x_test)
  rf_probs = model.predict_proba(x_test)[:, 1]

  auc = roc_auc_score(y_test, rf_probs) if len(np.unique(y_test)) > 1 else np.nan
  acc = np.mean(rf_predictions == y_test)

  return auc, acc, rf_predictions, rf_probs 

def run_task_rf(task): 
  """Train an RF for `task`, then score it on the whole test split and on each age cohort.

  Returns (cohorts, all_auc, all_acc, all_probs, all_preds, all_labels), one entry
  per cohort. Cohort bounds are ages in hours, compared with the last column of the
  unscaled test covariates (the age column of static_vec).
  """
  imbalanced = task in ['Mortality Prediction', 'Sepsis Prediction']
  model, x_test, y_test, x_test_orig = train_rf(task, imbalanced=imbalanced)

  cohorts = ['Total', '0 - 2 Month', '2 Month - 2 Years', '2 Years - 5 Years', '5 Years - 12 Years']
  threshs = [-1, 60 * 24, 2 * 365 * 24, 5 * 365 * 24, 12 *  365 * 24]

  all_auc, all_acc, all_probs, all_preds, all_labels = [], [], [], [], []
  for i, upper in enumerate(threshs):
    if i == 0:
      mask = [True] * len(x_test)
    else:
      lower = threshs[i - 1]
      ages = (x_test_orig[p][-1] for p in range(len(x_test)))
      mask = [not (age <= lower or age > upper) for age in ages]

    auc, acc, preds, probs = evaluate_rf(model, x_test, y_test, mask)
    all_auc.append(auc)
    all_acc.append(acc)
    all_probs.append(probs)
    all_preds.append(preds)
    all_labels.append(y_test[mask, ...])

  return cohorts, all_auc, all_acc, all_probs, all_preds, all_labels



In [0]:
# Train the random forest on the mortality task (EXPIRE_FLAG label) and score it per age cohort.
cohorts, all_auc, all_acc, all_probs, all_preds, all_labels = run_task_rf('Mortality Prediction')

In [0]:
print(cohorts) 
print(all_auc)
print(all_acc)
print([len(labels_i) for labels_i in all_labels])
print([np.sum(labels_i) for labels_i in all_labels])

# NOTE(review): the numbers below were recorded from an earlier run that used a
# 7-way age split (cohort names on the first line), not the 5 cohorts in run_task_rf.
#['Total', '0 - 1 Week', '<1 Week - 1 Month', '<1 Month - 1 Year', '1 Year - 5 Years', '5 Years - 12 Years', '12 Years - 18 Years']
#[0.8171612386874347, 0.6792912408061897, 0.813731722822632, 0.8320398009950248, 0.8785932925206281, 0.7720345640219952, 0.8766666666666666]
#[0.9056764831412719, 0.9102564102564102, 0.9025974025974026, 0.8835616438356164, 0.9221854304635762, 0.9067796610169492, 0.941747572815534]
#[2343, 390, 154, 730, 604, 354, 103]
#[148.0, 29.0, 11.0, 60.0, 26.0, 19.0, 3.0]

# One ROC curve per age cohort.
fig, ax = plt.subplots()
ax.set_title('RF ROC Curves')
for i, cohort in enumerate(cohorts):
  fpr, tpr, thresholds = roc_curve(all_labels[i], all_probs[i])
  ax.plot(fpr, tpr, label='%s ROC (area = %0.2f)' % (cohort, all_auc[i]))

ax.plot([0, 1], [0, 1], 'r--')
ax.set_xlim([0.0, 1.0])
ax.set_ylim([0.0, 1.05])
ax.set_xlabel('1-Specificity')
ax.set_ylabel('Sensitivity')
ax.legend(loc="lower right")





In [0]:
class LSTM_Classifier(nn.Module):
  """LSTM over an hourly feature sequence, with a linear head giving one logit per timestep.

  Input:  (batch, seq_len, input_size); the LSTM is batch_first.
  Output: (batch, seq_len, 1) raw logits (no sigmoid applied).
  """

  def __init__(self, input_size, hidden_size, num_layers=1, dropout=0., bidirectional=False):
    super().__init__()

    self.input_size = input_size
    self.hidden_size = hidden_size
    self.num_layers = num_layers
    self.bidirectional = bidirectional
    self.dropout = dropout

    self.rnn = nn.LSTM(input_size, hidden_size, num_layers=num_layers, batch_first=True,
                       dropout=dropout, bidirectional=bidirectional)
    # A bidirectional LSTM concatenates both directions' hidden states.
    head_in_features = hidden_size * (1 + int(bidirectional))
    self.out = nn.Linear(head_in_features, 1)

  def forward(self, input):
    sequence_out = self.rnn(input)[0]  # (batch, seq_len, directions * hidden_size)
    return self.out(sequence_out)

def rnn_train_one_sample(model, criterion, rnn_optimizer, sent_tensor, tag_tensor, alpha = 0.5, clip=None):
    """Run one optimisation step on a single sequence, supervising only its last timestep.

    sent_tensor: (num_hours, num_feats); tag_tensor: (num_hours) per-hour labels.
    alpha: unused; kept for the commented-out mixed per-step/final-step loss.
    clip: if not None, max L2 norm of the gradients before the step.
    Returns (per-hour logits (num_hours), loss as a Python float).
    """
    model.zero_grad()

    outputs = model(sent_tensor.unsqueeze(0)).squeeze(2).squeeze(0)

    # loss = criterion(outputs, tag_tensor) * alpha + criterion(outputs[-1], tag_tensor[-1]) * (1.0-alpha)
    # Cast the target to the logits' dtype so e.g. float64 labels still work with BCEWithLogitsLoss.
    loss = criterion(outputs[-1], tag_tensor[-1].to(outputs.dtype))

    loss.backward()

    if clip is not None:
        # clip_grad_norm_ is the in-place API; the un-suffixed clip_grad_norm is deprecated.
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip)

    rnn_optimizer.step()

    return outputs, loss.item()


In [0]:
import time
import math
import sklearn
from sklearn.metrics import precision_recall_fscore_support

def timeSince(since):
    """Format the wall-clock time elapsed since `since` (a time.time() value) as 'Xm Ys'."""
    elapsed = time.time() - since
    minutes = math.floor(elapsed / 60)
    return '%dm %ds' % (minutes, elapsed - minutes * 60)


def evaluate_result(true_tag_list, predicted_tag_list, probs):
  """Return (accuracy, ROC-AUC) for binary predictions.

  true_tag_list may be a CPU tensor or a NumPy array. AUC is NaN when the labels
  contain fewer than two classes (roc_auc_score raises in that case, e.g. for a
  small age cohort with no deaths).
  """
  true_tags = np.asarray(true_tag_list)
  acc = np.mean(true_tags == np.asarray(predicted_tag_list))
  if len(np.unique(true_tags)) < 2:
    return acc, np.nan
  return acc, roc_auc_score(true_tags, probs)

# Make prediction for one sentence.
def rnn_predict_one_sent(model, sent_tensor):
    """Classify one sequence from the model's last-timestep logit.

    sent_tensor: (num_hours, num_feats).
    Returns (predicted_tag_id, prob): prob is sigmoid(last logit) as a Python
    float, and the tag is 1 if prob > 0.5, else 0.
    """
    # Inference only: don't build an autograd graph for every test sequence.
    with torch.no_grad():
        outputs = model(sent_tensor.unsqueeze(0)).squeeze(2).squeeze(0)
        prob = torch.sigmoid(outputs[-1])

    predicted_tag_id = 1 if prob > 0.5 else 0
    return predicted_tag_id, prob.item()


def evaluate_rnn(model, x_test, y_test, mask = None): 
  """Score `model` on the test sequences selected by `mask` (all of them when None).

  x_test: (patients, hours, feats), as a NumPy array or a tensor.
  y_test: label tensor (patients, hours); only the last hour's label is scored.
  Returns (auc, acc, predicted_tags, probs) for the selected patients.
  """
  # `is None`: `mask == None` compares element-wise when mask is an array.
  if mask is None:
    mask = [True] * len(x_test)
  x_test = torch.as_tensor(x_test, dtype=torch.float32)[mask, ...]
  y_test = y_test[mask, ...]

  model.eval()
  predicted_tags = []
  probs = []

  for sent_tensor in x_test:
    predicted_tag_id, prob = rnn_predict_one_sent(model, sent_tensor.to(device))
    predicted_tags.append(predicted_tag_id)
    probs.append(prob)

  acc, auc = evaluate_result(y_test[:, -1], predicted_tags, probs)

  return auc, acc, predicted_tags, probs 

def train_model(model, criterion, optimizer, X_train, Y_train, X_test, Y_test, n_epochs=5, print_every=1000, plot_every=50, learning_rate=1e-3, alpha = 0.5, clip=None): 
  """Train `model` one sequence at a time, evaluating on (X_test, Y_test) after each epoch.

  X_train[i] is one (num_hours, num_feats) sequence and Y_train[i] its per-hour labels;
  both are converted to float32 tensors on `device`. Prints the running loss every
  `print_every` iterations and the test ACC/AUC after each epoch.

  Returns (all_losses, all_norms): all_losses holds the mean training loss of each
  block of `plot_every` iterations. all_norms is always empty, because
  rnn_train_one_sample reports no gradient norm. learning_rate is unused here; the
  optimizer already carries its own.
  """
  iter_count = 0

  current_loss = 0
  plot_loss = 0
  all_losses = []
  all_norms = []

  start = time.time()

  for epoch_i in range(n_epochs):
    # evaluate_rnn() below puts the model in eval mode (dropout off); switch back every epoch.
    model.train()

    for i in range(X_train.shape[0]):
      # as_tensor accepts NumPy rows and tensor rows alike (torch.tensor warns on tensors);
      # float32 targets match the float32 logits.
      sent_tensor = torch.as_tensor(X_train[i], dtype=torch.float32).to(device)
      tag_tensor = torch.as_tensor(Y_train[i], dtype=torch.float32).to(device)

      _, loss = rnn_train_one_sample(model, criterion, optimizer, sent_tensor, tag_tensor, alpha=alpha, clip=clip)
      current_loss += loss
      plot_loss += loss

      if iter_count % print_every == 0:
        print('%d %s %.4f' % (iter_count, timeSince(start), current_loss / print_every))
        current_loss = 0

      iter_count += 1
      if iter_count % plot_every == 0:
        all_losses.append(plot_loss / plot_every)
        plot_loss = 0

    auc, acc, _, _ = evaluate_rnn(model, X_test, Y_test)
    print("Epoch ", epoch_i, " ACC of ", acc, " AUC of ", auc)
  return all_losses, all_norms

def _plot_series(title, values):
  """Draw `values` against their index in a new figure titled `title`."""
  plt.figure()
  plt.title(title)
  plt.plot(values)
  plt.show()

def plot_losses(losses): 
  """Plot the recorded training-loss curve."""
  _plot_series('Losses vs Iterations', losses)

def plot_norms(norms): 
  """Plot the recorded gradient-norm curve."""
  _plot_series('Norms vs Iterations', norms)

Training and evaluation helpers. These hyperparameters still need tuning: `train_rnn` defaults to `n_epochs=10`, and `run_task_lstm` currently trains for 40 epochs.

In [0]:
# Evaluation

def train_rnn(task, imbalanced=False, 
              n_epochs=10, 
              rnn_clip = 1.0, 
              rnn_hidden_size = 32, 
              rnn_num_layers = 2, 
              learning_rate = 1e-4, 
              rnn_dropout = 0.5, 
              rnn_alpha = 0.5, 
              weight_decay = 1e-4,
              rnn_bidirectional = False): 
  """Train an LSTM_Classifier for `task` on an 80/20 split (random_state=1).

  Returns (rnn_model, x_test, y_test, x_test_orig). x_test is standardised with the
  scaler fitted on the training split; x_test_orig is the unscaled split (used for
  age cohorts).
  imbalanced: rebalance the training split with `oversample` (not defined in this
      part of the notebook; presumably an earlier cell).
  """
  X, Y = setup_data(task = task, model = 'LSTM')

  x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=1)
  x_test_orig = x_test

  # Standardise each feature over all (patient, hour) rows, fitting on the training split only.
  scaler = StandardScaler()
  x_train = scaler.fit_transform(x_train.reshape(-1, x_train.shape[-1])).reshape(x_train.shape)
  x_test = scaler.transform(x_test.reshape(-1, x_test.shape[-1])).reshape(x_test.shape)

  if imbalanced: 
    # oversample is called on flat (patients, WINDOW_SIZE * features) rows with one label
    # per patient; the result is reshaped back to sequences with per-hour labels.
    x_train = torch.from_numpy(x_train).float()
    x_train = torch.reshape(x_train, (x_train.shape[0], x_train.shape[1] * x_train.shape[2]))
    y_train = y_train[:, -1]

    x_train, y_train = oversample(x_train, y_train)

    x_train = torch.from_numpy(x_train).float()
    x_train = x_train.reshape(x_train.shape[0], WINDOW_SIZE, x_train.shape[1] // WINDOW_SIZE)
    y_train = torch.from_numpy(y_train).float()
    y_train = y_train.unsqueeze(1).expand((y_train.shape[0], WINDOW_SIZE))

  num_feats = x_train.shape[2]

  # Training and evaluation move every input to `device`, so the model must live there too.
  rnn_model = LSTM_Classifier(input_size=num_feats, hidden_size=rnn_hidden_size, num_layers=rnn_num_layers, dropout=rnn_dropout, bidirectional=rnn_bidirectional).to(device)
  criterion = nn.BCEWithLogitsLoss()
  rnn_optimizer = torch.optim.Adam(rnn_model.parameters(), lr=learning_rate, weight_decay=weight_decay)

  losses, norms = train_model(rnn_model, criterion, rnn_optimizer, x_train, y_train, x_test, y_test, n_epochs=n_epochs, alpha=rnn_alpha, clip=rnn_clip)

  return rnn_model, x_test, y_test, x_test_orig

def run_task_lstm(task): 
  """Train an LSTM for `task` (40 epochs), then score it on the whole test split and per age cohort.

  Returns (cohorts, all_auc, all_acc, all_probs, all_preds, all_labels), one entry
  per cohort. Cohort bounds are ages in hours, compared with the last feature of
  the last hour of the unscaled test covariates (the age column of static_vec).
  """
  imbalanced = task in ['Sepsis Prediction']
  model, x_test, y_test, x_test_orig = train_rnn(task, imbalanced=imbalanced, n_epochs=40)

  cohorts = ['Total', '0 - 2 Month', '2 Month - 2 Years', '2 Years - 5 Years', '5 Years - 12 Years']
  threshs = [-1, 60 * 24, 2 * 365 * 24, 5 * 365 * 24, 12 *  365 * 24]

  all_auc, all_acc, all_probs, all_preds, all_labels = [], [], [], [], []
  for i, upper in enumerate(threshs):
    if i == 0:
      mask = [True] * len(x_test)
    else:
      lower = threshs[i - 1]
      ages = (x_test_orig[p][-1][-1] for p in range(len(x_test)))
      mask = [not (age <= lower or age > upper) for age in ages]

    auc, acc, preds, probs = evaluate_rnn(model, x_test, y_test, mask)
    all_auc.append(auc)
    all_acc.append(acc)
    all_probs.append(probs)
    all_preds.append(preds)
    all_labels.append(y_test[mask, ...][:, -1])

  return cohorts, all_auc, all_acc, all_probs, all_preds, all_labels



Set up data and evaluate.

In [0]:
# Train/evaluate the LSTM on mortality and report per-cohort metrics.
cohorts, all_auc, all_acc, all_probs, all_preds, all_labels = run_task_lstm('Mortality Prediction')
print(cohorts) 
print(all_auc)
print(all_acc)
print([len(all_labels[i]) for i in range(len(all_labels))])
print([torch.sum(all_labels[i]).item() for i in range(len(all_labels))])

# One ROC curve per age cohort; these are LSTM results, so title them accordingly.
plt.title('LSTM ROC Curves')
for i in range(len(cohorts)): 
  fpr, tpr, thresholds = roc_curve(all_labels[i], all_probs[i])
  plt.plot(fpr, tpr, label='%s ROC (area = %0.2f)' % (cohorts[i], all_auc[i]))

plt.plot([0, 1], [0, 1],'r--')
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.05])
plt.xlabel('1-Specificity')
plt.ylabel('Sensitivity')
plt.legend(loc="lower right")



Previous Versions.

In [0]:
def train_and_evaluate(X_train, Y_train, X_test, Y_test): 
  """Earlier fixed-hyperparameter LSTM pipeline: train on (X_train, Y_train), then score X_test.

  X_*: (patients, hours, features) sequences; Y_*: per-hour label tensors (patients, hours).
  Returns (auc, acc, probs, rnn_model), judged on the last-hour labels Y_test[:, -1].
  """
  num_feats = X_train.shape[2]

  # Hyperparams
  rnn_clip = 1.0
  rnn_hidden_size = 128
  rnn_num_layers = 2
  learning_rate = 1e-3
  rnn_dropout = 0.5
  rnn_bidirectional = False
  rnn_alpha = 0.5

  # Inputs are moved to `device` during training and prediction, so the model must live there too.
  rnn_model = LSTM_Classifier(input_size=num_feats, hidden_size=rnn_hidden_size, num_layers=rnn_num_layers, dropout=rnn_dropout, bidirectional=rnn_bidirectional).to(device)
  criterion = nn.BCEWithLogitsLoss()
  rnn_optimizer = torch.optim.Adam(rnn_model.parameters(), lr=learning_rate)

  ### Train
  # train_model requires the test split (it evaluates after every epoch). Pass X_test as a
  # NumPy array, the form evaluate_rnn converts from.
  losses, norms = train_model(rnn_model, criterion, rnn_optimizer, X_train, Y_train,
                              np.asarray(X_test), Y_test, n_epochs=10, alpha=rnn_alpha, clip=rnn_clip)

  ### Test
  rnn_model.eval()
  predicted_tags = []
  probs = []

  for i in range(len(X_test)): 
    # as_tensor avoids torch.tensor's copy-construct warning when X_test is already a tensor.
    sent_tensor = torch.as_tensor(X_test[i], dtype=torch.float32).to(device)
    predicted_tag_id, prob = rnn_predict_one_sent(rnn_model, sent_tensor)
    predicted_tags.append(predicted_tag_id)
    probs.append(prob)

  acc, auc = evaluate_result(Y_test[:, -1], predicted_tags, probs)
  
  return auc, acc, probs, rnn_model


In [0]:

  
# Previous pipeline, end to end: mortality LSTM trained on an oversampled split.
X, Y = setup_data(task = 'Mortality Prediction', model = 'LSTM')
x_train, x_test, y_train, y_test = train_test_split(X, Y, test_size=0.2, random_state=1)

# Standardise each feature over all (patient, hour) rows, fitting on the training split only.
scaler = StandardScaler()
x_train = scaler.fit_transform(x_train.reshape(-1, x_train.shape[-1])).reshape(x_train.shape)
x_test = scaler.transform(x_test.reshape(-1, x_test.shape[-1])).reshape(x_test.shape)

# oversample is called on flat rows with one label per patient; afterwards restore the
# (patients, WINDOW_SIZE, features) layout and the per-hour labels.
n_train, n_hours, n_feats = x_train.shape
x_train, y_train = oversample(torch.from_numpy(x_train).float().reshape(n_train, n_hours * n_feats), y_train[:, -1])

x_train = torch.from_numpy(x_train).float()
x_train = x_train.reshape(x_train.shape[0], WINDOW_SIZE, x_train.shape[1] // WINDOW_SIZE)
y_train = torch.from_numpy(y_train).float()
y_train = y_train.unsqueeze(1).expand((y_train.shape[0], WINDOW_SIZE))

auc, acc, probs, model = train_and_evaluate(x_train, y_train, torch.from_numpy(x_test).float(), y_test)

print("AUC of ", auc)
print("ACC of ", acc)

plot_roc('LSTM Mortality Prediction', y_test[:, -1], probs)


In [0]:
# Re-score the trained `model` on the held-out split and plot its ROC.
model.eval()
predicted_tags = []
probs = []

# x_test is still the scaled NumPy array from the previous cell (only a converted copy
# was passed to train_and_evaluate), and a NumPy row has no .to(); convert each row.
with torch.no_grad():
  for i in range(len(x_test)): 
    sent_tensor = torch.as_tensor(x_test[i], dtype=torch.float32).to(device)
    predicted_tag_id, prob = rnn_predict_one_sent(model, sent_tensor)
    predicted_tags.append(predicted_tag_id)
    probs.append(prob)

print(predicted_tags) 
print(y_test[:, -1])
print(probs)
acc, auc = evaluate_result(y_test[:, -1], predicted_tags, probs)

print("AUC of ", auc)
print("ACC of ", acc)

plot_roc('LSTM Mortality Prediction', y_test[:, -1], probs)