## MIMICIII Mechanical Ventilation MAgECs

In [1]:
import numpy as np
import pandas as pd
import psycopg2
import os 
import random
import datetime
from sqlalchemy import create_engine
import matplotlib.pyplot as plt

%matplotlib inline

random.seed(22891)

In [2]:
pd.set_option('display.max_columns', None)

### Get data

In [3]:
# information used to create a database connection
sqluser = 'postgres'
dbname = 'mimic'
schema_name = 'mimiciii'

engine = create_engine("postgresql+psycopg2://{}:{}@/{}".format(sqluser, sqluser, dbname))

schema_name = 'mimiciii'
conn = engine.connect()
conn.execute('SET search_path to ' + schema_name)

df = pd.read_sql("SELECT * FROM mimic_users_study;", conn)
conn.close()

### Featurize

In [4]:
vitals = ['heartrate_mean', 'sysbp_mean', 'diasbp_mean', 'meanbp_mean',
          'resprate_mean', 'tempc_mean', 'spo2_mean', 'glucose_mean']
labs = ['aniongap', 'albumin', 'bicarbonate', 'bilirubin', 'creatinine', 
        'chloride', 'glucose', 'hemoglobin', 'lactate', 
        'magnesium', 'phosphate', 'platelet', 'potassium', 'ptt', 'inr', 
        'pt', 'sodium', 'bun', 'wbc']  # -hematocrit
comobs = ['congestive_heart_failure', 'chronic_pulmonary', 'pulmonary_circulation']
others = ['age', 'gender']

In [5]:
def last_val(x):
    vals = x[~np.isnan(x)]
    if len(vals):
        return vals[-1]
    else:
        return None
    
def featurize_time(df):
    out = dict()
    for i in range(len(df)):
        for lab in labs:
            val = last_val(df[lab].values[:i+1])
            if lab not in out:
                out[lab] = [val]
            else:
                out[lab].append(val)
        for vital in vitals:    
            val = last_val(df[vital].values[:i+1])
            if vital not in out:
                out[vital] = [val]
            else:
                out[vital].append(val)
        for comob in comobs:    
            val = last_val(df[comob].values[:i+1])
            if comob not in out:
                out[comob] = [val]
            else:
                out[comob].append(val)
        for other in others:
            val = last_val(df[other].values[:i+1])
            if other not in out:
                out[other] = [val]
            else:
                out[other].append(val)
        out['timepoint'] = df.timepoint.values
        out['label'] = [int(x) for x in df.ventilated.values]
    return pd.Series(out)

def featurize(df):
    out = dict()
    for lab in labs:
        out[lab] = last_val(df[lab])
    for vital in vitals:
        out[vital] = last_val(df[vital])
    for comob in comobs:
        out[comob] = last_val(df[comob])
    for other in others:
        out[other] = last_val(df[other])
    out['label'] = int(df.ventilated.iloc[-1])
    return pd.Series(out)

### Example from 'original' dataframe

In [6]:
df[df['subject_id']==4].head()

Unnamed: 0,subject_id,hadm_id,icustay_id,timepoint,event_time,ventilated,mv_start,aniongap,albumin,bicarbonate,bilirubin,creatinine,chloride,glucose,hematocrit,hemoglobin,lactate,magnesium,phosphate,platelet,potassium,ptt,inr,pt,sodium,bun,wbc,heartrate_mean,sysbp_mean,diasbp_mean,meanbp_mean,resprate_mean,tempc_mean,spo2_mean,glucose_mean,age,first_icu_stay,adult_icu,first_careunit,diagnosis,curr_service,dischtime,admission_type,mort_icu,gender,admittime,los_icu,mv_hours,los_icu_hr,mv_end,los_hospital,first_hosp_stay,outtime,intime,ventnum,prev_service,transfertime,ethnicity,congestive_heart_failure,cardiac_arrhythmias,valvular_disease,pulmonary_circulation,peripheral_vascular,hypertension,paralysis,other_neurological,chronic_pulmonary,diabetes_uncomplicated,diabetes_complicated,hypothyroidism,renal_failure,liver_disease,peptic_ulcer,aids,lymphoma,metastatic_cancer,solid_tumor,rheumatoid_arthritis,coagulopathy,obesity,weight_loss,fluid_electrolyte,blood_loss_anemia,deficiency_anemias,alcohol_abuse,drug_abuse,psychoses,depression
0,4,185777,294638,3,2191-03-17 03:29:31,0,NaT,17.0,2.8,24.0,2.2,0.5,97.0,140.0,34.2,11.5,2.1,1.9,3.2,207.0,3.1,31.3,1.0,12.3,135.0,9.0,9.7,97.0,119.0,69.0,85.666702,28.0,,98.0,,47.0,1,1,MICU,"FEVER,DEHYDRATION,FAILURE TO THRIVE",MED,2191-03-23 18:41:00,EMERGENCY,0,0,2191-03-16 00:28:00,1.0,,40.0,NaT,7.0,1,2191-03-17 16:46:31,2191-03-16 00:29:31,,,2191-03-16 00:29:31,white,0.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
1,4,185777,294638,4,2191-03-17 03:29:31,0,NaT,17.0,2.8,24.0,2.2,0.5,97.0,140.0,34.2,11.5,2.1,1.9,3.2,207.0,3.1,31.3,1.0,12.3,135.0,9.0,9.7,94.0,,,,,,97.0,153.0,47.0,1,1,MICU,"FEVER,DEHYDRATION,FAILURE TO THRIVE",MED,2191-03-23 18:41:00,EMERGENCY,0,0,2191-03-16 00:28:00,1.0,,40.0,NaT,7.0,1,2191-03-17 16:46:31,2191-03-16 00:29:31,,,2191-03-16 00:29:31,white,0.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
2,4,185777,294638,5,2191-03-17 03:29:31,0,NaT,17.0,2.8,24.0,2.2,0.5,97.0,140.0,34.2,11.5,2.1,1.9,3.2,207.0,3.1,31.3,1.0,12.3,135.0,9.0,9.7,99.0,133.0,79.0,97.0,26.0,,98.0,,47.0,1,1,MICU,"FEVER,DEHYDRATION,FAILURE TO THRIVE",MED,2191-03-23 18:41:00,EMERGENCY,0,0,2191-03-16 00:28:00,1.0,,40.0,NaT,7.0,1,2191-03-17 16:46:31,2191-03-16 00:29:31,,,2191-03-16 00:29:31,white,0.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
3,4,185777,294638,6,2191-03-17 03:29:31,0,NaT,17.0,2.8,24.0,2.2,0.5,97.0,140.0,34.2,11.5,2.1,1.9,3.2,207.0,3.1,31.3,1.0,12.3,135.0,9.0,9.7,92.0,,,,24.0,36.666667,97.0,,47.0,1,1,MICU,"FEVER,DEHYDRATION,FAILURE TO THRIVE",MED,2191-03-23 18:41:00,EMERGENCY,0,0,2191-03-16 00:28:00,1.0,,40.0,NaT,7.0,1,2191-03-17 16:46:31,2191-03-16 00:29:31,,,2191-03-16 00:29:31,white,0.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
4,4,185777,294638,7,2191-03-17 03:29:31,0,NaT,17.0,2.8,24.0,2.2,0.5,97.0,140.0,34.2,11.5,2.1,1.9,3.2,207.0,3.1,31.3,1.0,12.3,135.0,9.0,9.7,89.0,139.0,81.0,100.333,25.0,,97.0,,47.0,1,1,MICU,"FEVER,DEHYDRATION,FAILURE TO THRIVE",MED,2191-03-23 18:41:00,EMERGENCY,0,0,2191-03-16 00:28:00,1.0,,40.0,NaT,7.0,1,2191-03-17 16:46:31,2191-03-16 00:29:31,,,2191-03-16 00:29:31,white,0.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0


In [7]:
df[df['subject_id']==4].tail()

Unnamed: 0,subject_id,hadm_id,icustay_id,timepoint,event_time,ventilated,mv_start,aniongap,albumin,bicarbonate,bilirubin,creatinine,chloride,glucose,hematocrit,hemoglobin,lactate,magnesium,phosphate,platelet,potassium,ptt,inr,pt,sodium,bun,wbc,heartrate_mean,sysbp_mean,diasbp_mean,meanbp_mean,resprate_mean,tempc_mean,spo2_mean,glucose_mean,age,first_icu_stay,adult_icu,first_careunit,diagnosis,curr_service,dischtime,admission_type,mort_icu,gender,admittime,los_icu,mv_hours,los_icu_hr,mv_end,los_hospital,first_hosp_stay,outtime,intime,ventnum,prev_service,transfertime,ethnicity,congestive_heart_failure,cardiac_arrhythmias,valvular_disease,pulmonary_circulation,peripheral_vascular,hypertension,paralysis,other_neurological,chronic_pulmonary,diabetes_uncomplicated,diabetes_complicated,hypothyroidism,renal_failure,liver_disease,peptic_ulcer,aids,lymphoma,metastatic_cancer,solid_tumor,rheumatoid_arthritis,coagulopathy,obesity,weight_loss,fluid_electrolyte,blood_loss_anemia,deficiency_anemias,alcohol_abuse,drug_abuse,psychoses,depression
19,4,185777,294638,22,2191-03-17 03:29:31,0,NaT,17.0,2.8,24.0,2.2,0.5,97.0,140.0,34.2,11.5,2.1,1.9,3.2,207.0,3.1,31.3,1.0,12.3,135.0,9.0,9.7,74.0,101.0,61.0,74.333298,,,100.0,179.333333,47.0,1,1,MICU,"FEVER,DEHYDRATION,FAILURE TO THRIVE",MED,2191-03-23 18:41:00,EMERGENCY,0,0,2191-03-16 00:28:00,1.0,,40.0,NaT,7.0,1,2191-03-17 16:46:31,2191-03-16 00:29:31,,,2191-03-16 00:29:31,white,0.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
20,4,185777,294638,23,2191-03-17 03:29:31,0,NaT,17.0,2.8,24.0,2.2,0.5,97.0,140.0,34.2,11.5,2.1,1.9,3.2,207.0,3.1,31.3,1.0,12.3,135.0,9.0,9.7,80.0,101.0,57.0,71.666702,,,100.0,,47.0,1,1,MICU,"FEVER,DEHYDRATION,FAILURE TO THRIVE",MED,2191-03-23 18:41:00,EMERGENCY,0,0,2191-03-16 00:28:00,1.0,,40.0,NaT,7.0,1,2191-03-17 16:46:31,2191-03-16 00:29:31,,,2191-03-16 00:29:31,white,0.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
21,4,185777,294638,24,2191-03-17 03:29:31,0,NaT,17.0,2.8,24.0,2.2,0.5,97.0,140.0,34.2,11.5,2.1,1.9,3.2,207.0,3.1,31.3,1.0,12.3,135.0,9.0,9.7,88.0,,,,,,99.0,,47.0,1,1,MICU,"FEVER,DEHYDRATION,FAILURE TO THRIVE",MED,2191-03-23 18:41:00,EMERGENCY,0,0,2191-03-16 00:28:00,1.0,,40.0,NaT,7.0,1,2191-03-17 16:46:31,2191-03-16 00:29:31,,,2191-03-16 00:29:31,white,0.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
22,4,185777,294638,25,2191-03-17 03:29:31,0,NaT,17.0,2.8,24.0,2.2,0.5,97.0,140.0,34.2,11.5,2.1,1.9,3.2,207.0,3.1,31.3,1.0,12.3,135.0,9.0,9.7,100.0,116.0,63.0,80.666702,,,98.0,,47.0,1,1,MICU,"FEVER,DEHYDRATION,FAILURE TO THRIVE",MED,2191-03-23 18:41:00,EMERGENCY,0,0,2191-03-16 00:28:00,1.0,,40.0,NaT,7.0,1,2191-03-17 16:46:31,2191-03-16 00:29:31,,,2191-03-16 00:29:31,white,0.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0
23,4,185777,294638,26,2191-03-17 03:29:31,0,NaT,17.0,2.8,24.0,2.2,0.5,97.0,140.0,34.2,11.5,2.1,1.9,3.2,207.0,3.1,31.3,1.0,12.3,135.0,9.0,9.7,111.0,,,,,37.444445,98.0,,47.0,1,1,MICU,"FEVER,DEHYDRATION,FAILURE TO THRIVE",MED,2191-03-23 18:41:00,EMERGENCY,0,0,2191-03-16 00:28:00,1.0,,40.0,NaT,7.0,1,2191-03-17 16:46:31,2191-03-16 00:29:31,,,2191-03-16 00:29:31,white,0.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0


### Dataframe w/o time (for 'static' models)

In [8]:
df_ml = df.set_index(['subject_id', 'timepoint']).groupby(level=0, group_keys=False).\
                                                  apply(featurize).reset_index()

In [9]:
df_ml[df_ml['subject_id']==4].head()

Unnamed: 0,subject_id,aniongap,albumin,bicarbonate,bilirubin,creatinine,chloride,glucose,hemoglobin,lactate,magnesium,phosphate,platelet,potassium,ptt,inr,pt,sodium,bun,wbc,heartrate_mean,sysbp_mean,diasbp_mean,meanbp_mean,resprate_mean,tempc_mean,spo2_mean,glucose_mean,congestive_heart_failure,chronic_pulmonary,pulmonary_circulation,age,gender,label
0,4,17.0,2.8,24.0,2.2,0.5,97.0,140.0,11.5,2.1,1.9,3.2,207.0,3.1,31.3,1.0,12.3,135.0,9.0,9.7,111.0,116.0,63.0,80.666702,18.0,37.444445,98.0,179.333333,0.0,0.0,0.0,47.0,0.0,0.0


### Dataframe w/ time (for 'timepoint' MAgECs)

In [10]:
df_time = df.set_index(['subject_id']).groupby(level=0, group_keys=False).\
                                       apply(featurize_time).apply(pd.Series.explode).reset_index()

In [11]:
df_time[df_time['subject_id']==4].head()

Unnamed: 0,subject_id,aniongap,albumin,bicarbonate,bilirubin,creatinine,chloride,glucose,hemoglobin,lactate,magnesium,phosphate,platelet,potassium,ptt,inr,pt,sodium,bun,wbc,heartrate_mean,sysbp_mean,diasbp_mean,meanbp_mean,resprate_mean,tempc_mean,spo2_mean,glucose_mean,congestive_heart_failure,chronic_pulmonary,pulmonary_circulation,age,gender,timepoint,label
0,4,17,2.8,24,2.2,0.5,97,140,11.5,2.1,1.9,3.2,207,3.1,31.3,1,12.3,135,9,9.7,97,119,69,85.6667,28,,98,,0,0,0,47,0,3,0
1,4,17,2.8,24,2.2,0.5,97,140,11.5,2.1,1.9,3.2,207,3.1,31.3,1,12.3,135,9,9.7,94,119,69,85.6667,28,,97,153.0,0,0,0,47,0,4,0
2,4,17,2.8,24,2.2,0.5,97,140,11.5,2.1,1.9,3.2,207,3.1,31.3,1,12.3,135,9,9.7,99,133,79,97.0,26,,98,153.0,0,0,0,47,0,5,0
3,4,17,2.8,24,2.2,0.5,97,140,11.5,2.1,1.9,3.2,207,3.1,31.3,1,12.3,135,9,9.7,92,133,79,97.0,24,36.6667,97,153.0,0,0,0,47,0,6,0
4,4,17,2.8,24,2.2,0.5,97,140,11.5,2.1,1.9,3.2,207,3.1,31.3,1,12.3,135,9,9.7,89,139,81,100.333,25,36.6667,97,153.0,0,0,0,47,0,7,0


In [12]:
df_time[df_time['subject_id']==4].tail()

Unnamed: 0,subject_id,aniongap,albumin,bicarbonate,bilirubin,creatinine,chloride,glucose,hemoglobin,lactate,magnesium,phosphate,platelet,potassium,ptt,inr,pt,sodium,bun,wbc,heartrate_mean,sysbp_mean,diasbp_mean,meanbp_mean,resprate_mean,tempc_mean,spo2_mean,glucose_mean,congestive_heart_failure,chronic_pulmonary,pulmonary_circulation,age,gender,timepoint,label
19,4,17,2.8,24,2.2,0.5,97,140,11.5,2.1,1.9,3.2,207,3.1,31.3,1,12.3,135,9,9.7,74,101,61,74.3333,18,36.6667,100,179.333,0,0,0,47,0,22,0
20,4,17,2.8,24,2.2,0.5,97,140,11.5,2.1,1.9,3.2,207,3.1,31.3,1,12.3,135,9,9.7,80,101,57,71.6667,18,36.6667,100,179.333,0,0,0,47,0,23,0
21,4,17,2.8,24,2.2,0.5,97,140,11.5,2.1,1.9,3.2,207,3.1,31.3,1,12.3,135,9,9.7,88,101,57,71.6667,18,36.6667,99,179.333,0,0,0,47,0,24,0
22,4,17,2.8,24,2.2,0.5,97,140,11.5,2.1,1.9,3.2,207,3.1,31.3,1,12.3,135,9,9.7,100,116,63,80.6667,18,36.6667,98,179.333,0,0,0,47,0,25,0
23,4,17,2.8,24,2.2,0.5,97,140,11.5,2.1,1.9,3.2,207,3.1,31.3,1,12.3,135,9,9.7,111,116,63,80.6667,18,37.4444,98,179.333,0,0,0,47,0,26,0


### Train/Valid Split

In [13]:
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

In [14]:
seed = 7
np.random.seed(seed)

x = df_ml[list(set(df_ml.columns) - {'subject_id', 'label'})]
Y = df_ml[['subject_id', 'label']]

x_train, x_validation, Y_train, Y_validation = train_test_split(x.copy(), Y, test_size=0.2, random_state=seed)

### Impute vitals+labs with mean and co-morbidities with 0

In [15]:
def impute(df):
    df[vitals+labs] = df[vitals+labs].fillna(df[vitals+labs].mean())
    df[comobs] = df[comobs].fillna(0)
    return df

In [16]:
x_train = impute(x_train)
x_validation = impute(x_validation)

### Scale data

In [17]:
from sklearn.preprocessing import StandardScaler

stsc = StandardScaler()
xst_train = stsc.fit_transform(x_train)
xst_train = pd.DataFrame(xst_train, index=x_train.index, columns=x_train.columns)

xst_validation = stsc.transform(x_validation)
xst_validation = pd.DataFrame(xst_validation, index=x_validation.index, columns=x_validation.columns)

### Train 'static' models
These are single timepoint (single row) models. The training data is grouped by patient and all timepoints are condenced to a single 'last' timepoint. 

In [18]:
from sklearn.metrics import accuracy_score
from sklearn.metrics import precision_score
from sklearn.metrics import recall_score
from sklearn.metrics import f1_score
from sklearn.metrics import roc_auc_score
from sklearn.metrics import confusion_matrix


def predict(model, data):
    """
    Model output (predicted) probabilities.
    Wrapper for predict_proba function in scikit-learn models.
    When a model does not have a predict_proba use predict interface.
    """
    if hasattr(model, 'predict_proba'):
        probs = model.predict_proba(data)
        if probs.shape[1] == 2:
            probs = probs[:, 1].ravel()
        else:
            probs = probs.ravel()
    else:
        probs = np.array(model.predict(data))
    return probs


def predict_classes(model, data):
    """
    Model output (predicted) classes.
    """
    if hasattr(model, 'predict_classes'):
        return model.predict_classes(data).ravel()
    else:
         return model.predict(data).ravel()

    
def evaluate(model, x_test, y_test):
    # predict probabilities for test set
    yhat_probs = predict(model, x_test)

    # predict classes for test set
    yhat_classes = predict_classes(model, x_test)
    
    # reduce to 1d array
    if len(yhat_probs[0].shape):
        yhat_probs = yhat_probs[:, 0]
        yhat_classes = yhat_classes[:, 0]
 
    # accuracy: (tp + tn) / (p + n)
    accuracy = accuracy_score(y_test, yhat_classes)
    print('Accuracy: %f' % accuracy)

    # precision tp / (tp + fp)
    precision = precision_score(y_test, yhat_classes)
    print('Precision: %f' % precision)

    # recall: tp / (tp + fn)
    recall = recall_score(y_test, yhat_classes)
    print('Recall: %f' % recall)

    # f1: 2 tp / (2 tp + fp + fn)
    f1 = f1_score(y_test, yhat_classes)
    print('F1 score: %f' % f1)

    # ROC AUC
    auc = roc_auc_score(y_test, yhat_probs)
    print('ROC AUC: %f' % auc)

    # confusion matrix
    matrix = confusion_matrix(y_test, yhat_classes)
    print(matrix)

In [19]:
from sklearn.utils import class_weight
class_weights = class_weight.compute_class_weight('balanced', np.unique(Y_train['label']), Y_train['label'])
class_weights

array([0.61911131, 2.59887711])

#### LR

In [20]:
from sklearn.linear_model import LogisticRegression
lr = LogisticRegression(C=1., class_weight='balanced', solver='lbfgs')
lr.fit(xst_train, Y_train['label'])

LogisticRegression(C=1.0, class_weight='balanced', dual=False,
          fit_intercept=True, intercept_scaling=1, max_iter=100,
          multi_class='warn', n_jobs=None, penalty='l2', random_state=None,
          solver='lbfgs', tol=0.0001, verbose=0, warm_start=False)

In [21]:
evaluate(lr, xst_validation, Y_validation['label'])

Accuracy: 0.650504
Precision: 0.323094
Recall: 0.681055
F1 score: 0.438272
ROC AUC: 0.693912
[[1071  595]
 [ 133  284]]


#### RG

In [22]:
from sklearn.ensemble import RandomForestClassifier
from sklearn.calibration import CalibratedClassifierCV
rf = CalibratedClassifierCV(RandomForestClassifier(n_estimators=800, 
                                                   min_samples_split=2, 
                                                   min_samples_leaf=4, 
                                                   max_features='sqrt', 
                                                   max_depth=90, 
                                                   bootstrap=True, 
                                                   n_jobs=-1),
                            method='sigmoid', cv=5)
rf.fit(xst_train, Y_train['label'])

CalibratedClassifierCV(base_estimator=RandomForestClassifier(bootstrap=True, class_weight=None, criterion='gini',
            max_depth=90, max_features='sqrt', max_leaf_nodes=None,
            min_impurity_decrease=0.0, min_impurity_split=None,
            min_samples_leaf=4, min_samples_split=2,
            min_weight_fraction_leaf=0.0, n_estimators=800, n_jobs=-1,
            oob_score=False, random_state=None, verbose=0,
            warm_start=False),
            cv=5, method='sigmoid')

In [23]:
evaluate(rf, xst_validation, Y_validation['label'])

Accuracy: 0.858377
Precision: 0.912162
Recall: 0.323741
F1 score: 0.477876
ROC AUC: 0.821395
[[1653   13]
 [ 282  135]]


#### MLP

In [24]:
from keras.models import Sequential
from keras.layers import Dense
from keras.layers import Dropout
from keras.wrappers.scikit_learn import KerasClassifier

mlp = Sequential()
mlp.add(Dense(60, input_dim=len(xst_train.columns), activation='relu'))
mlp.add(Dropout(0.2))
mlp.add(Dense(30, input_dim=60, activation='relu'))
mlp.add(Dropout(0.2))
mlp.add(Dense(1, activation='sigmoid'))
mlp.compile(loss='binary_crossentropy', 
            loss_weights=[class_weights[1]], optimizer='adam', metrics=['accuracy'])
mlp.fit(xst_train, Y_train['label'], epochs=100, batch_size=64, verbose=0)

Using TensorFlow backend.


<keras.callbacks.callbacks.History at 0x124c45a90>

In [25]:
evaluate(mlp, xst_validation, Y_validation['label'])

Accuracy: 0.855017
Precision: 0.777778
Recall: 0.386091
F1 score: 0.516026
ROC AUC: 0.798166
[[1620   46]
 [ 256  161]]


### Time-aware (LSTM) model

#### Data pre-processing

In [26]:
# Get train/valid
train_ind = df_time[~np.isin(df_time['subject_id'], Y_validation.subject_id.unique())].index
valid_ind = df_time[np.isin(df_time['subject_id'], Y_validation.subject_id.unique())].index

# Impute
df_series_train = impute(df_time.iloc[train_ind].copy())
df_series_valid = impute(df_time.iloc[valid_ind].copy())

# Get X, Y as numpy arrays
df_series_train_X = df_series_train[list(set(df_series_train.columns) - 
                                         {'subject_id', 'label', 'timepoint'})].astype(float)

df_series_train_Y = df_series_train[['subject_id', 'label', 'timepoint']]

df_series_valid_X = df_series_valid[list(set(df_series_valid.columns) - 
                                         {'subject_id', 'label', 'timepoint'})].astype(float)

df_series_valid_Y = df_series_valid[['subject_id', 'label', 'timepoint']]

# scale
stsc2 = StandardScaler()
tmp = stsc2.fit_transform(df_series_train_X)
df_series_train_X = pd.DataFrame(tmp, index=df_series_train_X.index, columns=df_series_train_X.columns)
tmp = stsc2.transform(df_series_valid_X)
df_series_valid_X = pd.DataFrame(tmp, index=df_series_valid_X.index, columns=df_series_valid_X.columns)

In [27]:
# concat X/Y for train/valid
df_series_train = pd.concat([df_series_train_X, df_series_train_Y], axis=1)
df_series_valid = pd.concat([df_series_valid_X, df_series_valid_Y], axis=1)

In [28]:
df_series_valid.head()

Unnamed: 0,bilirubin,tempc_mean,lactate,glucose,wbc,pulmonary_circulation,albumin,creatinine,sodium,chronic_pulmonary,heartrate_mean,potassium,glucose_mean,age,diasbp_mean,ptt,bicarbonate,gender,magnesium,platelet,bun,chloride,inr,aniongap,phosphate,spo2_mean,meanbp_mean,resprate_mean,congestive_heart_failure,hemoglobin,pt,sysbp_mean,subject_id,label,timepoint
0,0.185142,-0.045702,-0.242366,-0.174261,-0.184449,-0.261491,-1.179369,-0.57922,-0.45567,-0.499044,0.636508,-1.308177,-0.022705,-0.879606,0.42589,-0.187468,0.082431,-1.134352,-0.089849,-0.33529,-0.823223,-0.73269,-0.398609,0.139907,-0.274201,0.465272,0.426152,1.524211,-0.61177,-0.152695,-0.418429,-0.082093,4,0,3
1,0.185142,-0.045702,-0.242366,-0.174261,-0.184449,-0.261491,-1.179369,-0.57922,-0.45567,-0.499044,0.47526,-1.308177,0.173433,-0.879606,0.42589,-0.187468,0.082431,-1.134352,-0.089849,-0.33529,-0.823223,-0.73269,-0.398609,0.139907,-0.274201,0.113185,0.426152,1.524211,-0.61177,-0.152695,-0.418429,-0.082093,4,0,4
2,0.185142,-0.045702,-0.242366,-0.174261,-0.184449,-0.261491,-1.179369,-0.57922,-0.45567,-0.499044,0.744006,-1.308177,0.173433,-0.879606,1.088605,-0.187468,0.082431,-1.134352,-0.089849,-0.33529,-0.823223,-0.73269,-0.398609,0.139907,-0.274201,0.465272,1.162645,1.15837,-0.61177,-0.152695,-0.418429,0.563827,4,0,5
3,0.185142,-0.164977,-0.242366,-0.174261,-0.184449,-0.261491,-1.179369,-0.57922,-0.45567,-0.499044,0.367762,-1.308177,0.173433,-0.879606,1.088605,-0.187468,0.082431,-1.134352,-0.089849,-0.33529,-0.823223,-0.73269,-0.398609,0.139907,-0.274201,0.113185,1.162645,0.792528,-0.61177,-0.152695,-0.418429,0.563827,4,0,6
4,0.185142,-0.164977,-0.242366,-0.174261,-0.184449,-0.261491,-1.179369,-0.57922,-0.45567,-0.499044,0.206514,-1.308177,0.173433,-0.879606,1.221148,-0.187468,0.082431,-1.134352,-0.089849,-0.33529,-0.823223,-0.73269,-0.398609,0.139907,-0.274201,0.113185,1.379239,0.975449,-0.61177,-0.152695,-0.418429,0.84065,4,0,7


In [29]:
def zero_pad(df):
    x = list()
    y = list()    
    series_cols = set(df.columns) - {'subject_id', 'timepoint'}    
    for _, fname in df.set_index(['subject_id']).groupby(level=0, group_keys=False):       
        y_data = np.array(fname['label'].values[0])
        tmp = fname[series_cols].astype(float).values
        x_data = np.zeros([25, tmp.shape[1]])
        x_data[:tmp.shape[0],:] = tmp
        x.append(x_data)
        y.append(y_data)    
    return np.array(x), np.array(y)

In [30]:
xt_train, Yt_train = zero_pad(df_series_train)
xt_valid, Yt_valid = zero_pad(df_series_valid)

In [31]:
len(xt_train), len(xt_valid)

(8332, 2083)

In [32]:
from keras.layers import LSTM
lstm = Sequential()
lstm.add(LSTM(128, dropout=0.5, recurrent_dropout=0.2, input_shape=xt_train.shape[1:]))
lstm.add(Dense(1, activation='sigmoid'))
lstm.compile(loss='binary_crossentropy',
             loss_weights=[class_weights[1]],
             optimizer='adam', 
             metrics=['accuracy'])

In [33]:
lstm.fit(xt_train, Yt_train, epochs=100, batch_size=64, verbose=0)

<keras.callbacks.callbacks.History at 0x12e521b70>

In [34]:
evaluate(lstm, xt_valid, Yt_valid)

Accuracy: 0.983677
Precision: 0.930337
Recall: 0.992806
F1 score: 0.960557
ROC AUC: 0.998674
[[1635   31]
 [   3  414]]


### MAgECs

In [35]:
import magec_utils as mg

In [44]:
x_magec = xst_validation.copy()
x_magec['timepoint'] = 0
x_magec['case'] = np.arange(len(x_magec))
x_magec.set_index(['case','timepoint'], inplace=True)

y_magec = pd.DataFrame(Y_validation['label'].copy())
y_magec['timepoint'] = 0
y_magec['case'] = np.arange(len(y_magec))
y_magec.set_index(['case','timepoint'], inplace=True)

In [41]:
magecs_lr = mg.case_magecs(lr, x_magec, model_name='lr')
magecs_lr = mg.normalize_magecs(magecs_lr, features=None, model_name='lr')

magecs_rf = mg.case_magecs(rf, x_magec, model_name='rf')
magecs_rf = mg.normalize_magecs(magecs_rf, features=None, model_name='rf')

magecs_mlp = mg.case_magecs(mlp, x_magec, model_name='mlp')
magecs_mlp = mg.normalize_magecs(magecs_mlp, features=None, model_name='mlp')

In [47]:
x_magec.columns

Index(['bilirubin', 'tempc_mean', 'lactate', 'glucose', 'wbc',
       'pulmonary_circulation', 'albumin', 'creatinine', 'sodium',
       'chronic_pulmonary', 'heartrate_mean', 'potassium', 'glucose_mean',
       'age', 'diasbp_mean', 'ptt', 'bicarbonate', 'gender', 'magnesium',
       'platelet', 'bun', 'chloride', 'inr', 'aniongap', 'phosphate',
       'spo2_mean', 'meanbp_mean', 'resprate_mean', 'congestive_heart_failure',
       'hemoglobin', 'pt', 'sysbp_mean'],
      dtype='object')

In [118]:
feats = vitals+labs+comobs
joined = mg.magec_models(magecs_mlp, magecs_rf, magecs_lr, Xdata=x_magec, Ydata=y_magec, features=feats)

In [120]:
joined.head(3)

Unnamed: 0,case,timepoint,mlp_bilirubin,mlp_tempc_mean,mlp_lactate,mlp_glucose,mlp_wbc,mlp_pulmonary_circulation,mlp_albumin,mlp_creatinine,mlp_sodium,mlp_chronic_pulmonary,mlp_heartrate_mean,mlp_potassium,mlp_glucose_mean,mlp_diasbp_mean,mlp_ptt,mlp_bicarbonate,mlp_magnesium,mlp_platelet,mlp_bun,mlp_chloride,mlp_inr,mlp_aniongap,mlp_phosphate,mlp_spo2_mean,mlp_meanbp_mean,mlp_resprate_mean,mlp_congestive_heart_failure,mlp_hemoglobin,mlp_pt,mlp_sysbp_mean,perturb_bilirubin_prob_mlp,perturb_tempc_mean_prob_mlp,perturb_lactate_prob_mlp,perturb_glucose_prob_mlp,perturb_wbc_prob_mlp,perturb_pulmonary_circulation_prob_mlp,perturb_albumin_prob_mlp,perturb_creatinine_prob_mlp,perturb_sodium_prob_mlp,perturb_chronic_pulmonary_prob_mlp,perturb_heartrate_mean_prob_mlp,perturb_potassium_prob_mlp,perturb_glucose_mean_prob_mlp,perturb_diasbp_mean_prob_mlp,perturb_ptt_prob_mlp,perturb_bicarbonate_prob_mlp,perturb_magnesium_prob_mlp,perturb_platelet_prob_mlp,perturb_bun_prob_mlp,perturb_chloride_prob_mlp,perturb_inr_prob_mlp,perturb_aniongap_prob_mlp,perturb_phosphate_prob_mlp,perturb_spo2_mean_prob_mlp,perturb_meanbp_mean_prob_mlp,perturb_resprate_mean_prob_mlp,perturb_congestive_heart_failure_prob_mlp,perturb_hemoglobin_prob_mlp,perturb_pt_prob_mlp,perturb_sysbp_mean_prob_mlp,orig_prob_mlp,bilirubin,tempc_mean,lactate,glucose,wbc,pulmonary_circulation,albumin,creatinine,sodium,chronic_pulmonary,heartrate_mean,potassium,glucose_mean,age,diasbp_mean,ptt,bicarbonate,gender,magnesium,platelet,bun,chloride,inr,aniongap,phosphate,spo2_mean,meanbp_mean,resprate_mean,congestive_heart_failure,hemoglobin,pt,sysbp_mean,label,rf_bilirubin,rf_tempc_mean,rf_lactate,rf_glucose,rf_wbc,rf_pulmonary_circulation,rf_albumin,rf_creatinine,rf_sodium,rf_chronic_pulmonary,rf_heartrate_mean,rf_potassium,rf_glucose_mean,rf_diasbp_mean,rf_ptt,rf_bicarbonate,rf_magnesium,rf_platelet,rf_bun,rf_chloride,rf_inr,rf_aniongap,rf_phosphate,rf_spo2_mean,rf_meanbp_mean,rf_resprate_mean,rf_congestive_heart_failure,rf_hemoglobin,rf_pt,rf_sysbp_mean,perturb_bilirubin_prob_rf,perturb_tempc_mean_prob_rf,perturb_lactate_prob_rf,perturb_glucose_prob_rf,perturb_wbc_prob_rf,perturb_pulmonary_circulation_prob_rf,perturb_albumin_prob_rf,perturb_creatinine_prob_rf,perturb_sodium_prob_rf,perturb_chronic_pulmonary_prob_rf,perturb_heartrate_mean_prob_rf,perturb_potassium_prob_rf,perturb_glucose_mean_prob_rf,perturb_diasbp_mean_prob_rf,perturb_ptt_prob_rf,perturb_bicarbonate_prob_rf,perturb_magnesium_prob_rf,perturb_platelet_prob_rf,perturb_bun_prob_rf,perturb_chloride_prob_rf,perturb_inr_prob_rf,perturb_aniongap_prob_rf,perturb_phosphate_prob_rf,perturb_spo2_mean_prob_rf,perturb_meanbp_mean_prob_rf,perturb_resprate_mean_prob_rf,perturb_congestive_heart_failure_prob_rf,perturb_hemoglobin_prob_rf,perturb_pt_prob_rf,perturb_sysbp_mean_prob_rf,orig_prob_rf,lr_bilirubin,lr_tempc_mean,lr_lactate,lr_glucose,lr_wbc,lr_pulmonary_circulation,lr_albumin,lr_creatinine,lr_sodium,lr_chronic_pulmonary,lr_heartrate_mean,lr_potassium,lr_glucose_mean,lr_diasbp_mean,lr_ptt,lr_bicarbonate,lr_magnesium,lr_platelet,lr_bun,lr_chloride,lr_inr,lr_aniongap,lr_phosphate,lr_spo2_mean,lr_meanbp_mean,lr_resprate_mean,lr_congestive_heart_failure,lr_hemoglobin,lr_pt,lr_sysbp_mean,perturb_bilirubin_prob_lr,perturb_tempc_mean_prob_lr,perturb_lactate_prob_lr,perturb_glucose_prob_lr,perturb_wbc_prob_lr,perturb_pulmonary_circulation_prob_lr,perturb_albumin_prob_lr,perturb_creatinine_prob_lr,perturb_sodium_prob_lr,perturb_chronic_pulmonary_prob_lr,perturb_heartrate_mean_prob_lr,perturb_potassium_prob_lr,perturb_glucose_mean_prob_lr,perturb_diasbp_mean_prob_lr,perturb_ptt_prob_lr,perturb_bicarbonate_prob_lr,perturb_magnesium_prob_lr,perturb_platelet_prob_lr,perturb_bun_prob_lr,perturb_chloride_prob_lr,perturb_inr_prob_lr,perturb_aniongap_prob_lr,perturb_phosphate_prob_lr,perturb_spo2_mean_prob_lr,perturb_meanbp_mean_prob_lr,perturb_resprate_mean_prob_lr,perturb_congestive_heart_failure_prob_lr,perturb_hemoglobin_prob_lr,perturb_pt_prob_lr,perturb_sysbp_mean_prob_lr,orig_prob_lr
0,0,0,-0.114349,-0.235846,-0.006834,-0.02387,-0.012122,0.0,-0.154228,0.111788,0.057085,0.0,-0.006499,0.014651,0.030535,0.52724,0.018557,-0.064537,-0.154154,-0.070255,0.225778,-0.038726,0.05429,-0.442167,0.006308,0.218972,-0.504774,0.084505,0.0,-0.075772,0.013988,0.107361,0.063463,0.049593,0.078674,0.076059,0.077854,0.079746,0.058552,0.099268,0.089233,0.079746,0.078726,0.082089,0.0847,0.212174,0.082724,0.070136,0.058561,0.069338,0.123443,0.073844,0.088746,0.032397,0.080747,0.121868,0.028431,0.094144,0.079746,0.068576,0.081982,0.09842,0.079746,-0.35271,-1.826785,-0.031654,-0.600679,0.128995,-0.26247,1.841195,-0.300016,0.437431,-0.5003,0.105532,-0.447039,-0.644794,-1.378663,0.97873,-0.387913,0.665911,0.875002,0.900106,-0.284242,-0.739836,1.105866,-0.398171,-1.686628,-1.205328,-1.02749,0.701371,-0.418527,-0.624917,0.705209,-0.460111,0.518912,0.0,0.001606,0.075255,-0.007008,0.009196,-0.003086,0.0,-0.05456,0.064784,0.012914,0.0,0.189849,0.026456,0.106302,0.651558,0.093702,-0.027306,-0.042805,-0.007686,0.040905,-0.026778,0.013197,-0.134682,-0.009449,0.222465,0.386178,0.415402,0.0,0.017093,-0.008118,0.310347,0.07135,0.078807,0.070521,0.072087,0.070897,0.071195,0.066106,0.077705,0.072451,0.071195,0.091834,0.07379,0.08216,0.165522,0.080784,0.068604,0.067173,0.070457,0.075244,0.068654,0.072479,0.059244,0.070288,0.095881,0.118684,0.123217,0.071195,0.072862,0.070415,0.107589,0.071195,0.042911,0.064272,-0.005507,-0.06194,0.025506,0.0,-0.065165,-0.023988,0.021273,0.0,-0.013071,-0.017383,-0.003778,0.567373,0.054354,0.015028,-0.090214,-0.02728,-0.065582,0.064708,-0.008887,-0.531647,0.288921,-0.143238,-0.411765,0.114449,0.0,-0.084871,-0.054683,0.126002,0.464989,0.470307,0.452966,0.439022,0.460662,0.454331,0.438228,0.44839,0.45961,0.454331,0.451092,0.450025,0.453394,0.594892,0.467837,0.458059,0.432071,0.447576,0.438126,0.470416,0.452128,0.328523,0.526414,0.419108,0.355493,0.482824,0.454331,0.433382,0.44081,0.48571,0.454331
1,1,0,-0.43348,0.016087,-0.581537,0.029451,-0.248602,0.0,-0.077667,0.051383,0.126043,0.0,-0.009515,0.09279,0.001142,0.011603,0.320031,0.006838,-0.037614,0.010104,-0.042318,-0.001139,-0.306989,-0.189114,-0.067046,-0.012096,0.035869,-0.013175,0.0,0.235972,-0.131398,-0.038791,0.995203,0.998309,0.993243,0.998361,0.996874,0.998245,0.997898,0.998442,0.99869,0.998245,0.998205,0.998585,0.998249,0.998291,0.999165,0.998272,0.998084,0.998285,0.998063,0.99824,0.996421,0.997277,0.997949,0.998195,0.998385,0.99819,0.998245,0.998985,0.997618,0.998079,0.998245,-0.413196,-0.02881,-0.839728,-0.522573,-0.3714,-0.26247,0.410005,-0.07215,0.437431,-0.5003,-0.006108,0.538677,-0.514202,0.434554,-0.040753,-0.329892,0.665911,0.875002,0.302222,-0.03518,-0.106741,0.173632,-0.323868,-0.23196,0.400326,-0.008735,-0.051094,0.016036,-0.624917,-0.710615,-0.322022,-0.063619,1.0,-0.245606,0.04452,-0.19101,-0.022999,-0.053551,0.0,-0.162873,-0.000848,0.00226,0.0,0.000984,0.012515,0.081693,0.480561,0.017951,-0.02186,-0.004465,0.000274,-0.007289,0.003112,-0.027859,-0.015339,-0.030668,0.083569,0.473152,0.421431,0.0,0.021414,-0.049796,0.468171,0.570084,0.691301,0.594067,0.664688,0.65228,0.673882,0.606261,0.673546,0.674779,0.673882,0.674273,0.67883,0.705442,0.831145,0.680967,0.665147,0.672108,0.673991,0.670983,0.675116,0.662729,0.667765,0.661594,0.706145,0.829259,0.815622,0.673882,0.682324,0.653817,0.827981,0.673882,0.195,0.003932,-0.566742,-0.209028,-0.284867,0.0,-0.05629,-0.022378,0.082518,0.0,0.002934,0.081255,-0.011686,-0.091642,0.179307,0.058296,-0.1175,-0.013098,-0.036704,0.039411,-0.028041,-0.283628,-0.372236,-0.004724,0.116361,-0.01701,0.0,0.331749,-0.148459,-0.059924,0.560311,0.548141,0.511507,0.53451,0.529642,0.54789,0.544293,0.546461,0.553154,0.54789,0.548078,0.553074,0.547144,0.542031,0.559314,0.55161,0.540376,0.547054,0.545545,0.550406,0.546099,0.529722,0.524027,0.547588,0.55531,0.546804,0.54789,0.568977,0.538393,0.544061,0.54789
2,2,0,-0.002917,0.071149,-0.065245,0.05602,0.011676,0.0,-0.001042,0.016279,-0.098757,0.0,-0.451479,0.000857,0.200035,-0.302756,0.102591,0.092758,-0.058322,-0.063632,-0.023055,0.009261,-0.022212,-0.302103,-0.0847,-0.625747,0.276193,0.159173,0.0,0.059023,0.021961,0.099989,0.147847,0.157621,0.140014,0.155582,0.149732,0.148222,0.148088,0.150331,0.135949,0.148222,0.098932,0.148332,0.175872,0.113302,0.161926,0.160569,0.140867,0.140213,0.145277,0.149419,0.145384,0.113369,0.137642,0.084176,0.187417,0.169912,0.148222,0.155985,0.151072,0.161566,0.148222,-0.035112,1.485625,0.275666,-0.322968,-0.155845,-0.26247,-0.013588,-0.356983,-0.475392,-0.5003,0.878376,-0.008943,-0.478586,0.84399,-0.556382,-0.289724,0.665911,0.875002,-0.494957,-0.735665,-0.655424,0.484376,-0.323868,-0.855389,0.323866,-2.709738,-0.53071,-1.889047,-0.624917,-0.877183,-0.368052,0.387472,0.0,-0.025973,0.321763,-0.039406,0.032697,0.02189,0.0,0.0,0.043925,-0.059588,0.0,0.328485,-0.00019,0.097192,0.446459,0.068037,-0.010355,-0.011337,-0.064283,0.029567,0.014077,0.013231,-0.047499,-0.02514,0.312599,0.430172,0.361714,0.0,0.009258,-0.044524,0.373156,0.097219,0.137897,0.095886,0.103238,0.102105,0.099844,0.099844,0.104427,0.093914,0.099844,0.138809,0.099825,0.110232,0.155645,0.10702,0.09879,0.098691,0.09346,0.102909,0.101293,0.101205,0.095091,0.097303,0.136662,0.153225,0.14339,0.099844,0.100795,0.095383,0.144997,0.099844,0.004893,-0.059865,0.054932,-0.038143,-0.035293,0.0,0.000551,-0.032691,-0.026478,0.0,-0.124603,-0.000398,-0.003211,-0.36941,0.046495,0.017212,0.056817,-0.080867,-0.066542,0.032462,-0.008279,-0.308815,-0.088914,-0.43265,0.356852,0.591647,0.0,0.12091,-0.050099,0.107759,0.558539,0.544554,0.569283,0.549254,0.54987,0.557485,0.557604,0.550432,0.551774,0.557485,0.530503,0.5574,0.556794,0.477117,0.567476,0.56119,0.569687,0.540002,0.543108,0.564466,0.555701,0.49033,0.538257,0.463362,0.632407,0.67865,0.557485,0.583349,0.546668,0.580556,0.557485


In [121]:
ranks1 = mg.magec_rank(joined, rank=1, features=feats)

In [122]:
ranks1.head()

Unnamed: 0,case,timepoint,mlp_magec,mlp_feat,rf_magec,rf_feat,lr_magec,lr_feat,perturb_heartrate_mean_prob_mlp,perturb_heartrate_mean_prob_rf,perturb_heartrate_mean_prob_lr,perturb_sysbp_mean_prob_mlp,perturb_sysbp_mean_prob_rf,perturb_sysbp_mean_prob_lr,perturb_diasbp_mean_prob_mlp,perturb_diasbp_mean_prob_rf,perturb_diasbp_mean_prob_lr,perturb_meanbp_mean_prob_mlp,perturb_meanbp_mean_prob_rf,perturb_meanbp_mean_prob_lr,perturb_resprate_mean_prob_mlp,perturb_resprate_mean_prob_rf,perturb_resprate_mean_prob_lr,perturb_tempc_mean_prob_mlp,perturb_tempc_mean_prob_rf,perturb_tempc_mean_prob_lr,perturb_spo2_mean_prob_mlp,perturb_spo2_mean_prob_rf,perturb_spo2_mean_prob_lr,perturb_glucose_mean_prob_mlp,perturb_glucose_mean_prob_rf,perturb_glucose_mean_prob_lr,perturb_aniongap_prob_mlp,perturb_aniongap_prob_rf,perturb_aniongap_prob_lr,perturb_albumin_prob_mlp,perturb_albumin_prob_rf,perturb_albumin_prob_lr,perturb_bicarbonate_prob_mlp,perturb_bicarbonate_prob_rf,perturb_bicarbonate_prob_lr,perturb_bilirubin_prob_mlp,perturb_bilirubin_prob_rf,perturb_bilirubin_prob_lr,perturb_creatinine_prob_mlp,perturb_creatinine_prob_rf,perturb_creatinine_prob_lr,perturb_chloride_prob_mlp,perturb_chloride_prob_rf,perturb_chloride_prob_lr,perturb_glucose_prob_mlp,perturb_glucose_prob_rf,perturb_glucose_prob_lr,perturb_hemoglobin_prob_mlp,perturb_hemoglobin_prob_rf,perturb_hemoglobin_prob_lr,perturb_lactate_prob_mlp,perturb_lactate_prob_rf,perturb_lactate_prob_lr,perturb_magnesium_prob_mlp,perturb_magnesium_prob_rf,perturb_magnesium_prob_lr,perturb_phosphate_prob_mlp,perturb_phosphate_prob_rf,perturb_phosphate_prob_lr,perturb_platelet_prob_mlp,perturb_platelet_prob_rf,perturb_platelet_prob_lr,perturb_potassium_prob_mlp,perturb_potassium_prob_rf,perturb_potassium_prob_lr,perturb_ptt_prob_mlp,perturb_ptt_prob_rf,perturb_ptt_prob_lr,perturb_inr_prob_mlp,perturb_inr_prob_rf,perturb_inr_prob_lr,perturb_pt_prob_mlp,perturb_pt_prob_rf,perturb_pt_prob_lr,perturb_sodium_prob_mlp,perturb_sodium_prob_rf,perturb_sodium_prob_lr,perturb_bun_prob_mlp,perturb_bun_prob_rf,perturb_bun_prob_lr,perturb_wbc_prob_mlp,perturb_wbc_prob_rf,perturb_wbc_prob_lr,perturb_congestive_heart_failure_prob_mlp,perturb_congestive_heart_failure_prob_rf,perturb_congestive_heart_failure_prob_lr,perturb_chronic_pulmonary_prob_mlp,perturb_chronic_pulmonary_prob_rf,perturb_chronic_pulmonary_prob_lr,perturb_pulmonary_circulation_prob_mlp,perturb_pulmonary_circulation_prob_rf,perturb_pulmonary_circulation_prob_lr,orig_prob_mlp,orig_prob_rf,orig_prob_lr,heartrate_mean,sysbp_mean,diasbp_mean,meanbp_mean,resprate_mean,tempc_mean,spo2_mean,glucose_mean,aniongap,albumin,bicarbonate,bilirubin,creatinine,chloride,glucose,hemoglobin,lactate,magnesium,phosphate,platelet,potassium,ptt,inr,pt,sodium,bun,wbc,congestive_heart_failure,chronic_pulmonary,pulmonary_circulation
0,0,0,-0.504774,meanbp_mean,-0.134682,aniongap,-0.531647,aniongap,0.078726,0.091834,0.451092,0.09842,0.107589,0.48571,0.212174,0.165522,0.594892,0.028431,0.118684,0.355493,0.094144,0.123217,0.482824,0.049593,0.078807,0.470307,0.121868,0.095881,0.419108,0.0847,0.08216,0.453394,0.032397,0.059244,0.328523,0.058552,0.066106,0.438228,0.070136,0.068604,0.458059,0.063463,0.07135,0.464989,0.099268,0.077705,0.44839,0.073844,0.068654,0.470416,0.076059,0.072087,0.439022,0.068576,0.072862,0.433382,0.078674,0.070521,0.452966,0.058561,0.067173,0.432071,0.080747,0.070288,0.526414,0.069338,0.070457,0.447576,0.082089,0.07379,0.450025,0.082724,0.080784,0.467837,0.088746,0.072479,0.452128,0.081982,0.070415,0.44081,0.089233,0.072451,0.45961,0.123443,0.075244,0.438126,0.077854,0.070897,0.460662,0.079746,0.071195,0.454331,0.079746,0.071195,0.454331,0.079746,0.071195,0.454331,0.079746,0.071195,0.454331,0.105532,0.518912,0.97873,0.701371,-0.418527,-1.826785,-1.02749,-0.644794,-1.686628,1.841195,0.665911,-0.35271,-0.300016,1.105866,-0.600679,0.705209,-0.031654,0.900106,-1.205328,-0.284242,-0.447039,-0.387913,-0.398171,-0.460111,0.437431,-0.739836,0.128995,-0.624917,-0.5003,-0.26247
1,1,0,-0.581537,lactate,-0.245606,bilirubin,-0.566742,lactate,0.998205,0.674273,0.548078,0.998079,0.827981,0.544061,0.998291,0.831145,0.542031,0.998385,0.829259,0.55531,0.99819,0.815622,0.546804,0.998309,0.691301,0.548141,0.998195,0.706145,0.547588,0.998249,0.705442,0.547144,0.997277,0.667765,0.529722,0.997898,0.606261,0.544293,0.998272,0.665147,0.55161,0.995203,0.570084,0.560311,0.998442,0.673546,0.546461,0.99824,0.675116,0.550406,0.998361,0.664688,0.53451,0.998985,0.682324,0.568977,0.993243,0.594067,0.511507,0.998084,0.672108,0.540376,0.997949,0.661594,0.524027,0.998285,0.673991,0.547054,0.998585,0.67883,0.553074,0.999165,0.680967,0.559314,0.996421,0.662729,0.546099,0.997618,0.653817,0.538393,0.99869,0.674779,0.553154,0.998063,0.670983,0.545545,0.996874,0.65228,0.529642,0.998245,0.673882,0.54789,0.998245,0.673882,0.54789,0.998245,0.673882,0.54789,0.998245,0.673882,0.54789,-0.006108,-0.063619,-0.040753,-0.051094,0.016036,-0.02881,-0.008735,-0.514202,-0.23196,0.410005,0.665911,-0.413196,-0.07215,0.173632,-0.522573,-0.710615,-0.839728,0.302222,0.400326,-0.03518,0.538677,-0.329892,-0.323868,-0.322022,0.437431,-0.106741,-0.3714,-0.624917,-0.5003,-0.26247
2,2,0,-0.625747,spo2_mean,-0.064283,platelet,-0.43265,spo2_mean,0.098932,0.138809,0.530503,0.161566,0.144997,0.580556,0.113302,0.155645,0.477117,0.187417,0.153225,0.632407,0.169912,0.14339,0.67865,0.157621,0.137897,0.544554,0.084176,0.136662,0.463362,0.175872,0.110232,0.556794,0.113369,0.095091,0.49033,0.148088,0.099844,0.557604,0.160569,0.09879,0.56119,0.147847,0.097219,0.558539,0.150331,0.104427,0.550432,0.149419,0.101293,0.564466,0.155582,0.103238,0.549254,0.155985,0.100795,0.583349,0.140014,0.095886,0.569283,0.140867,0.098691,0.569687,0.137642,0.097303,0.538257,0.140213,0.09346,0.540002,0.148332,0.099825,0.5574,0.161926,0.10702,0.567476,0.145384,0.101205,0.555701,0.151072,0.095383,0.546668,0.135949,0.093914,0.551774,0.145277,0.102909,0.543108,0.149732,0.102105,0.54987,0.148222,0.099844,0.557485,0.148222,0.099844,0.557485,0.148222,0.099844,0.557485,0.148222,0.099844,0.557485,0.878376,0.387472,-0.556382,-0.53071,-1.889047,1.485625,-2.709738,-0.478586,-0.855389,-0.013588,0.665911,-0.035112,-0.356983,0.484376,-0.322968,-0.877183,0.275666,-0.494957,0.323866,-0.735665,-0.008943,-0.289724,-0.323868,-0.368052,-0.475392,-0.655424,-0.155845,-0.624917,-0.5003,-0.26247
3,3,0,-0.420867,phosphate,-0.291951,ptt,-0.593165,phosphate,0.000274,0.233019,0.437168,0.000538,0.252624,0.603595,0.000299,0.251999,0.430798,0.000181,0.230127,0.357189,0.000421,0.262253,0.507422,0.000497,0.249749,0.489951,0.001146,0.39666,0.516654,0.000332,0.227009,0.481721,0.001204,0.235178,0.670669,0.000445,0.183945,0.499452,4e-05,0.218342,0.461229,0.000617,0.220729,0.492988,0.010871,0.21409,0.556654,0.000488,0.211003,0.469845,0.000416,0.207659,0.465734,0.00102,0.20855,0.455811,0.000343,0.212744,0.457154,0.000555,0.200392,0.510563,2.5e-05,0.20194,0.196295,0.000388,0.206982,0.469115,0.000324,0.214094,0.49587,0.000482,0.159215,0.313986,0.000494,0.207005,0.48343,0.035262,0.194098,0.74706,0.001874,0.201506,0.450921,0.000135,0.215283,0.549866,0.000505,0.198142,0.450143,0.000504,0.209399,0.483186,0.000504,0.209399,0.483186,0.000504,0.209399,0.483186,0.000504,0.209399,0.483186,1.496652,2.00856,-0.364493,0.886183,-0.354591,-0.769633,0.96062,-1.000955,2.469566,-1.839008,-3.904048,-0.322467,3.687648,-0.913975,-0.678785,0.913419,-0.600715,-1.092842,5.599586,-0.587785,1.305344,5.097254,0.043779,9.677909,-2.666168,3.016528,-0.671637,-0.624917,-0.5003,-0.26247
4,4,0,-0.14502,bicarbonate,-0.056514,chronic_pulmonary,-0.555587,diasbp_mean,0.005274,0.120544,0.229112,0.006965,0.127264,0.258863,0.006713,0.130469,0.139589,0.005738,0.107861,0.218313,0.005824,0.107169,0.18371,0.005127,0.11674,0.215648,0.006618,0.19754,0.227055,0.006733,0.097917,0.218225,0.009455,0.101805,0.265177,0.005734,0.092832,0.218715,0.004246,0.091379,0.214027,0.005595,0.089945,0.224126,0.005781,0.092728,0.220767,0.00563,0.090425,0.226685,0.005861,0.09314,0.219365,0.012035,0.092961,0.244534,0.022723,0.088837,0.316864,0.005895,0.092452,0.20687,0.005465,0.093479,0.211713,0.01221,0.091707,0.254958,0.006956,0.093661,0.22668,0.006482,0.10015,0.22762,0.006493,0.091196,0.217965,0.005735,0.092411,0.21887,0.006678,0.092202,0.222288,0.006016,0.092613,0.244214,0.009945,0.091433,0.282928,0.005752,0.092832,0.218632,0.005662,0.087579,0.22273,0.005752,0.092832,0.218632,0.005752,0.092832,0.218632,-0.486982,0.91323,-0.94016,0.00319,0.796251,0.498952,0.348894,-0.407354,0.807088,-0.013588,-1.203618,-0.26198,0.155717,0.795121,0.041527,-1.210318,2.904808,0.700811,0.170947,2.097408,1.19582,-0.37006,-0.175263,0.011692,0.437431,1.623718,1.737957,-0.624917,1.998801,-0.26247


In [123]:
mg.print_ranks_stats(ranks1)

	 mlp MAgEC Stats
**** mlp_feat ****
meanbp_mean                 410
diasbp_mean                 390
lactate                     182
congestive_heart_failure    121
resprate_mean               107
bilirubin                   100
heartrate_mean               91
sysbp_mean                   59
tempc_mean                   49
chloride                     48
platelet                     47
aniongap                     47
spo2_mean                    47
albumin                      45
phosphate                    43
sodium                       36
bicarbonate                  34
chronic_pulmonary            31
hemoglobin                   27
inr                          25
glucose_mean                 23
wbc                          21
magnesium                    20
ptt                          19
pulmonary_circulation        17
bun                          16
glucose                      11
potassium                     6
pt                            6
creatinine                    5
Nam

In [124]:
weights = {'mlp': 0.798166, 'rf': 0.821395, 'lr': 0.693912}
consensus1 = mg.magec_consensus(ranks1, use_weights=True, weights=weights)

In [149]:
consensus1.winner.value_counts()

diasbp_mean                 434
meanbp_mean                 424
congestive_heart_failure    210
lactate                     186
resprate_mean               141
aniongap                    109
bilirubin                    72
heartrate_mean               57
phosphate                    56
ptt                          52
albumin                      47
sysbp_mean                   38
platelet                     37
spo2_mean                    35
hemoglobin                   29
chloride                     25
tempc_mean                   22
sodium                       17
chronic_pulmonary            15
pulmonary_circulation        14
bicarbonate                  14
magnesium                    12
inr                          12
wbc                           7
glucose_mean                  7
glucose                       5
bun                           4
creatinine                    1
potassium                     1
Name: winner, dtype: int64

In [150]:
consensus1.head()

Unnamed: 0,case,timepoint,winner,score,consensus,models,avg_percent_consensus,avg_percent_all
0,0,0,aniongap,0.479543,2,"[lr, rf]",22.23811,34.617193
1,1,0,lactate,0.857432,2,"[lr, mlp]",3.570754,6.32854
2,2,0,spo2_mean,0.799671,2,"[lr, mlp]",30.046567,7.73922
3,3,0,phosphate,0.747526,2,"[lr, mlp]",77.219463,52.66697
4,4,0,diasbp_mean,0.385529,1,[lr],36.153488,-7.03655


In [151]:
consensus1[['winner','score','consensus']].groupby(['winner']).\
                                           agg(['mean','std', 'count']).reset_index().\
                                           sort_values([('score', 'count')], ascending=False)

Unnamed: 0_level_0,winner,score,score,score,consensus,consensus,consensus
Unnamed: 0_level_1,Unnamed: 1_level_1,mean,std,count,mean,std,count
9,diasbp_mean,0.730556,0.265143,434,1.739631,0.439342,434
17,meanbp_mean,0.814921,0.280102,424,1.846698,0.360704,424
7,congestive_heart_failure,0.608454,0.202633,210,1.752381,0.638194,210
15,lactate,0.687871,0.242851,186,1.844086,0.542659,186
23,resprate_mean,0.749922,0.410418,141,1.780142,0.7568,141
1,aniongap,0.491943,0.189519,109,1.376147,0.557591,109
3,bilirubin,0.782827,0.368894,72,2.027778,0.604495,72
12,heartrate_mean,0.471002,0.137775,57,1.157895,0.367884,57
18,phosphate,0.611081,0.264941,56,1.678571,0.606245,56
21,ptt,0.71371,0.257576,52,1.961538,0.558764,52
