In [1]:
import numpy as np
import pandas as pd
import psycopg2
import os 
import random
import datetime
from sqlalchemy import create_engine
import mimic_utils as mimic
import magec_utils as mg
import matplotlib.pyplot as plt
from adjustText import adjust_text

pd.set_option('display.max_columns', None)

random.seed(22891)

%matplotlib inline

Using TensorFlow backend.


In [2]:
def get_magecs(p=None, models=('lr', 'mlp', 'svm', 'lstm')):
    prefix = 'magec_'
    postfix = '_0.csv' if p is None else '_p' + str(p) + '.csv'
    magecs = []
    for m in models:
        filename = prefix + m + postfix
        magecs.append(pd.read_csv(filename))
    return magecs

In [4]:
vitals = ['heartrate_mean', 'sysbp_mean', 'diasbp_mean', 'meanbp_mean',
          'resprate_mean', 'tempc_mean', 'spo2_mean', 'glucose_mean']

labs = ['aniongap', 'albumin', 'bicarbonate', 'bilirubin', 'creatinine', 
        'chloride', 'glucose', 'hemoglobin', 'lactate', 
        'magnesium', 'phosphate', 'platelet', 'potassium', 'ptt', 'inr', 
        'pt', 'sodium', 'bun', 'wbc']  # -hematocrit

comobs = ['congestive_heart_failure', 'chronic_pulmonary', 'pulmonary_circulation']

others = ['age', 'gender']

features = vitals+labs

df_cohort = mimic.get_mimic_data()

df_ml = mimic.get_ml_data(df_cohort)

df_time = mimic.get_ml_series_data(df_cohort)

_, _, _, _, _, _, Y_validation = mimic.train_valid_ml(df_ml)

stsc2, series_means, _, df_series_valid, _, _, xt_valid, Yt_valid = mimic.train_valid_series(df_time, Y_validation)

In [5]:
df_notes = mimic.get_cohort_notes()

In [6]:
magecs = get_magecs()

In [7]:
x_magec_cols = list(set(df_series_valid.columns) - {'label'})
x_magec = df_series_valid[x_magec_cols]
y_magec = df_series_valid['label']

In [8]:
joined = mg.magec_models(*magecs, Xdata=x_magec, Ydata=y_magec, features=features)

In [9]:
def most_abnormal(x, features):
    res = None
    feat = None
    for f in features:
        if f == 'resprate_mean':
            continue
        if res is None or abs(x[f]) > res:
            res = abs(x[f])
            feat = f
    return feat

In [10]:
prob_cols = [c for c in joined.columns if c.startswith('perturb')]
joined['orig_prob_ensemble'] = joined[['orig_prob_mlp', 'orig_prob_lr', 
                                       'orig_prob_svm', 'orig_prob_lstm']].apply(np.mean, 1)
joined[['best_feat', 'new_risk']] = joined.apply(lambda x: mimic.best_feature(x, prob_cols), axis=1)
joined['most_abnormal'] = joined.apply(lambda x: most_abnormal(x, features), axis=1)

In [11]:
joined[features+['most_abnormal','best_feat']].head()

Unnamed: 0,heartrate_mean,sysbp_mean,diasbp_mean,meanbp_mean,resprate_mean,tempc_mean,spo2_mean,glucose_mean,aniongap,albumin,bicarbonate,bilirubin,creatinine,chloride,glucose,hemoglobin,lactate,magnesium,phosphate,platelet,potassium,ptt,inr,pt,sodium,bun,wbc,most_abnormal,best_feat
0,1.392477,-0.029624,-0.015069,-0.020085,0.026276,0.99616,0.465077,-0.02547,0.139907,-1.179369,0.082431,0.185142,-0.57922,-0.73269,-0.174261,-0.152695,-0.242366,-0.089849,-0.274201,-0.33529,-1.308177,-0.187468,-0.398609,-0.418429,-0.45567,-0.823223,-0.184449,heartrate_mean,heartrate_mean
1,0.80086,-0.216676,0.03125,0.101536,0.026276,0.99616,0.465077,-0.02547,0.139907,-1.179369,0.082431,0.185142,-0.57922,-0.73269,-0.174261,-0.152695,-0.242366,-0.089849,-0.274201,-0.33529,-1.308177,-0.187468,-0.398609,-0.418429,-0.45567,-0.823223,-0.184449,potassium,chloride
2,0.155459,-0.216676,0.03125,0.101536,0.026276,0.99616,0.818195,-0.02547,0.139907,-1.179369,0.082431,0.185142,-0.57922,-0.73269,-0.174261,-0.152695,-0.242366,-0.089849,-0.274201,-0.33529,-1.308177,-0.187468,-0.398609,-0.418429,-0.45567,-0.823223,-0.184449,potassium,chloride
3,-0.274808,-0.908913,-0.365904,-0.482395,0.026276,0.99616,1.171313,-0.02547,0.139907,-1.179369,0.082431,0.185142,-0.57922,-0.73269,-0.174261,-0.152695,-0.242366,-0.089849,-0.274201,-0.33529,-1.308177,-0.187468,-0.398609,-0.418429,-0.45567,-0.823223,-0.184449,potassium,sysbp_mean
4,-0.597508,-0.908913,-0.101135,-0.309383,0.026276,0.99616,1.171313,0.676408,0.139907,-1.179369,0.082431,0.185142,-0.57922,-0.73269,-0.174261,-0.152695,-0.242366,-0.089849,-0.274201,-0.33529,-1.308177,-0.187468,-0.398609,-0.418429,-0.45567,-0.823223,-0.184449,potassium,sysbp_mean


In [12]:
drivers = ['heartrate_mean', 'sysbp_mean', 'diasbp_mean', 
           'meanbp_mean', 'resprate_mean', 'spo2_mean']

In [23]:
def expected(x, drivers, sigmas=0.5, threshold=0.5):
    orig_prob_ensemble = x['orig_prob_ensemble']
    best_feat = x['best_feat']
    label = x['label']
    # Unexpected: ensemble_probability greater than 0.5, 
    # all drivers normal, best feature is not a driver 
    # patient ended up being ventillated
    cond1 = orig_prob_ensemble > threshold
    cond2 = np.all(abs(x[drivers]) <= sigmas)
    cond3 = np.isin(best_feat, drivers)
    cond4 = label == 1
    cond5 = x[best_feat] > sigmas
    if cond1 and cond2 and (not cond3) and cond4:
        return 'unexpexted_ventilated'
    # Expected: Best feature is a driver with an abnormal value, patient ended up being ventillated
    if cond3 and cond4 and cond5:
        return 'expected_ventilated'
    elif cond4:
        return 'other_ventilated'
    # Not Ventilated: Expected (prob<=thres) AND (drivers are normal) AND (label==0)
    elif (not cond1) and cond2 and (not cond4):
        return 'expected_notventilated'
    # Not Ventilated: UnExpected (prob<=thres) AND (one or more drivers are abnormal) (label==0)
    elif (not cond1) and (not cond2) and (not cond4):
        return 'unexpected_notventilated'
    elif (not cond4):
        return 'other_notventilated'
    else:
        return 'other'

In [24]:
joined['stats'] = joined.apply(lambda x: expected(x, drivers), axis=1)

In [25]:
joined['stats'].value_counts()

unexpected_notventilated    32871
other_ventilated             2701
expected_ventilated          2610
other_notventilated          2130
expected_notventilated        485
unexpexted_ventilated         232
Name: stats, dtype: int64

In [26]:
joined[['case','stats']].groupby('stats')['case'].nunique()

stats
expected_notventilated       312
expected_ventilated          248
other_notventilated          540
other_ventilated             250
unexpected_notventilated    1645
unexpexted_ventilated        143
Name: case, dtype: int64

In [71]:
248+1666+250+143

2307

In [72]:
joined['case'].nunique()

2083

In [67]:
np.all(abs(joined.iloc[1][drivers]) < 0.5)

False

In [56]:
joined.iloc[1][drivers]

heartrate_mean      0.80086
sysbp_mean        -0.216676
diasbp_mean         0.03125
meanbp_mean        0.101536
resprate_mean     0.0262763
spo2_mean          0.465077
Name: 1, dtype: object

In [12]:
def anomalies(x, drivers, sigmas=0.5, threshold=0.5):
    orig_prob_ensemble = x['orig_prob_ensemble']
    if orig_prob_ensemble < threshold:
        return False
    else:
        for driver in drivers:
            if abs(x[driver]) > sigmas:
                return False
    return True

In [13]:
joined['anomaly'] = joined.apply(lambda x: anomalies(x, drivers), axis=1)

In [14]:
np.sum(joined['anomaly'] == True), np.sum(joined['anomaly'] == False)

(345, 40684)

In [50]:
joined[(joined['anomaly'] == False) & 
       (np.isin(joined['best_feat'], drivers))][drivers+['best_feat']].head()

Unnamed: 0,heartrate_mean,sysbp_mean,diasbp_mean,meanbp_mean,resprate_mean,spo2_mean,best_feat
0,1.392477,-0.029624,-0.015069,-0.020085,0.026276,0.465077,heartrate_mean
3,-0.274808,-0.908913,-0.365904,-0.482395,0.026276,1.171313,sysbp_mean
4,-0.597508,-0.908913,-0.101135,-0.309383,0.026276,1.171313,sysbp_mean
5,-0.167241,-0.908913,-0.101135,-0.309383,0.026276,-0.594277,sysbp_mean
6,-0.221025,-1.09351,-0.498289,-0.655414,-0.304449,1.171313,sysbp_mean


In [51]:
joined[(joined['label'] == 1) & 
       (np.isin(joined['best_feat'], drivers)) & 
       (abs(joined[joined['best_feat']]) < 0.5)
       ][drivers+['orig_prob_ensemble',
                                         'most_abnormal',
                                         'best_feat', 
                                         'new_risk',
                                         'label']].sort_values(['orig_prob_ensemble'], ascending=False).head(20)

ValueError: cannot reindex from a duplicate axis

In [47]:
joined[joined['anomaly'] == True][features+['orig_prob_ensemble','most_abnormal','best_feat','label']].head()

Unnamed: 0,heartrate_mean,sysbp_mean,diasbp_mean,meanbp_mean,resprate_mean,tempc_mean,spo2_mean,glucose_mean,aniongap,albumin,bicarbonate,bilirubin,creatinine,chloride,glucose,hemoglobin,lactate,magnesium,phosphate,platelet,potassium,ptt,inr,pt,sodium,bun,wbc,orig_prob_ensemble,most_abnormal,best_feat,label
169,-0.013136,-0.029624,-0.015069,-0.020085,0.026276,0.344632,-0.005482,-0.02547,-0.663023,-0.010634,0.491466,-0.034093,-0.469036,-0.130071,-0.472639,-0.275809,-0.04501,0.285552,0.391555,0.155205,-0.026918,-0.413027,-0.327233,-0.329783,-0.27871,-0.781833,-0.274317,0.723658,bun,aniongap,0
263,-0.013136,-0.029624,-0.015069,-0.020085,0.026276,-0.051128,-0.005482,-0.811192,-0.863756,0.653979,-0.326604,-0.333316,-0.358853,0.321894,-0.61354,0.503913,5.912063,-0.022061,0.004589,-0.680169,-0.454004,0.245606,-0.327233,-0.329783,-0.63263,-0.285146,-0.476519,0.514972,lactate,aniongap,1
267,-0.013136,-0.029624,-0.015069,-0.020085,0.026276,-0.051128,-0.005482,3.488996,-0.863756,0.653979,-0.326604,-0.333316,-0.358853,0.321894,-0.61354,0.503913,5.912063,-0.022061,0.004589,-0.680169,-0.454004,0.245606,-0.327233,-0.329783,-0.63263,-0.285146,-0.476519,0.544689,lactate,glucose_mean,1
268,-0.013136,-0.029624,-0.015069,-0.020085,0.026276,-0.051128,-0.005482,2.504289,-0.863756,0.653979,-0.326604,-0.333316,-0.358853,0.321894,-0.61354,0.503913,5.912063,-0.022061,0.004589,-0.680169,-0.454004,0.245606,-0.327233,-0.329783,-0.63263,-0.285146,-0.476519,0.513879,lactate,glucose_mean,1
436,-0.005891,-0.308974,0.296019,0.231298,-0.121473,-0.144011,0.111959,-0.02547,-0.663023,0.857684,-0.122086,-0.363814,-0.469036,0.321894,0.25673,0.544951,-0.04501,-0.089849,-0.939957,0.423444,-1.201405,3.137274,-0.18448,-0.252218,-0.10175,-0.61627,0.18251,0.507273,ptt,ptt,0


In [16]:
joined[(joined['anomaly'] == True) & (joined['label'] == 1)]['case'].nunique()

144

In [17]:
len(joined[(joined['anomaly'] == True) & (joined['label'] == 1)])

233

In [44]:
joined[(joined['anomaly'] == True) & 
       (joined['label'] == 1)].best_feat.value_counts()

lactate         47
aniongap        47
glucose_mean    30
bilirubin       26
phosphate       21
pt               9
wbc              9
albumin          9
ptt              9
hemoglobin       6
platelet         6
glucose          4
chloride         4
inr              3
sodium           1
magnesium        1
sysbp_mean       1
Name: best_feat, dtype: int64

In [42]:
joined[(joined['anomaly'] == True) & 
       (joined['label'] == 1) & 
       (np.isin(joined['best_feat'], drivers))][drivers+['orig_prob_ensemble',
                                         'most_abnormal',
                                         'best_feat', 
                                         'new_risk',
                                         'label']].sort_values(['orig_prob_ensemble'], ascending=False).head(20)

Unnamed: 0,heartrate_mean,sysbp_mean,diasbp_mean,meanbp_mean,resprate_mean,spo2_mean,orig_prob_ensemble,most_abnormal,best_feat,new_risk,label
7354,-0.382375,-0.493571,0.03125,-0.028226,0.061503,0.111959,0.516373,phosphate,sysbp_mean,0.446265,1


In [41]:
joined[(joined['anomaly'] == True) & 
       (joined['label'] == 1)][drivers+['orig_prob_ensemble',
                                         'most_abnormal',
                                         'best_feat',
                                        'new_risk',
                                         'label']].sort_values(['orig_prob_ensemble'], ascending=False).head(20)

Unnamed: 0,heartrate_mean,sysbp_mean,diasbp_mean,meanbp_mean,resprate_mean,spo2_mean,orig_prob_ensemble,most_abnormal,best_feat,new_risk,label
30905,-0.013136,-0.029624,-0.015069,-0.020085,0.026276,-0.005482,0.91525,ptt,ptt,0.866585,1
37836,-0.013136,-0.029624,-0.015069,-0.020085,0.026276,-0.005482,0.893947,albumin,aniongap,0.851056,1
34653,-0.013136,-0.029624,-0.015069,-0.020085,0.026276,-0.005482,0.893794,ptt,ptt,0.847646,1
14248,-0.013136,-0.029624,-0.015069,-0.020085,0.026276,-0.005482,0.89285,glucose_mean,glucose_mean,0.870049,1
14249,-0.013136,-0.029624,-0.015069,-0.020085,0.026276,-0.005482,0.889763,ptt,glucose_mean,0.862005,1
37835,-0.013136,-0.029624,-0.015069,-0.020085,0.026276,-0.005482,0.887185,albumin,aniongap,0.841551,1
7090,-0.013136,-0.029624,-0.015069,-0.020085,0.026276,-0.005482,0.886232,ptt,aniongap,0.830021,1
37837,-0.013136,-0.029624,-0.015069,-0.020085,0.026276,-0.005482,0.876259,albumin,aniongap,0.831824,1
37849,-0.013136,-0.029624,-0.015069,-0.020085,0.026276,-0.005482,0.866486,chloride,lactate,0.797816,1
16033,-0.013136,-0.029624,-0.015069,-0.020085,0.026276,-0.005482,0.862369,potassium,bilirubin,0.844693,1


In [35]:
joined[(joined['anomaly'] == True) & 
       (joined['label'] == 1)][drivers+['orig_prob_ensemble',
                                        'most_abnormal',
                                        'best_feat', 
                                        'new_risk',
                                        'label']].sort_values(['new_risk']).head(20)

Unnamed: 0,heartrate_mean,sysbp_mean,diasbp_mean,meanbp_mean,resprate_mean,spo2_mean,orig_prob_ensemble,most_abnormal,best_feat,new_risk,label
8281,-0.013136,-0.029624,-0.015069,-0.020085,0.026276,-0.005482,0.524423,bun,phosphate,0.386701,1
268,-0.013136,-0.029624,-0.015069,-0.020085,0.026276,-0.005482,0.513879,lactate,glucose_mean,0.390706,1
267,-0.013136,-0.029624,-0.015069,-0.020085,0.026276,-0.005482,0.544689,lactate,glucose_mean,0.400683,1
8280,-0.013136,-0.029624,-0.015069,-0.020085,0.026276,-0.005482,0.545707,bun,phosphate,0.407059,1
12575,-0.013136,-0.029624,-0.015069,-0.020085,0.026276,-0.005482,0.520785,platelet,bilirubin,0.425828,1
19429,-0.013136,-0.029624,-0.015069,-0.020085,0.026276,-0.005482,0.644209,albumin,lactate,0.432681,1
10065,-0.436158,0.106368,-0.101135,0.166417,0.24448,0.465077,0.583278,phosphate,phosphate,0.44258,1
15510,-0.005891,0.383263,-0.432096,0.05828,0.427456,-0.241159,0.583117,ptt,albumin,0.445719,1
7354,-0.382375,-0.493571,0.03125,-0.028226,0.061503,0.111959,0.516373,phosphate,sysbp_mean,0.446265,1
28521,-0.013136,-0.029624,-0.015069,-0.020085,0.026276,-0.005482,0.71712,bilirubin,bilirubin,0.4471,1


In [32]:
index = 37836
case = joined.loc[index].case
timepoint = joined.loc[268].timepoint
best_feat = joined.loc[index].best_feat
abnormal = joined.loc[index].most_abnormal

In [33]:
df_cohort[(df_cohort.subject_id==case) & (df_cohort.timepoint==timepoint)][[best_feat, abnormal]]

Unnamed: 0,aniongap,albumin


In [34]:
mimic.print_notes(df_notes, case)




In [22]:
mimic.print_notes?

In [None]:
joined[(joined['anomaly'] == True) & 
       (joined['label'] == 1)]['abnormal'].value_counts()

In [None]:
joined[(joined['anomaly'] == True) & 
       (joined['label'] == 1)]['best_feat'].value_counts()

In [None]:
df_notes = mimic.get_cohort_notes()

## Scratchpad

In [None]:
rbos = pd.read_json('mimic_rbos_valid.json')

In [None]:
rbos[['lstm_lr', 'mlp_lr', 'mlp_lstm', 'rf_lr', 'rf_lstm', 'rf_mlp']].mean()

In [None]:
rbos[['lstm_lr', 'mlp_lr', 'mlp_lstm', 'rf_lr', 'rf_lstm', 'rf_mlp']].std()

In [None]:
joined = pd.read_csv('time_mimic_magecs.csv')
prob_cols = [c for c in joined.columns if c.startswith('perturb') and 'resprate_mean' not in c]
joined[['best_feat', 'new_risk']] = joined.apply(lambda x: mimic.best_feature(x, prob_cols), axis=1)

In [None]:
joined['orig_prob_ensemble'] = joined[['orig_prob_mlp', 'orig_prob_lr', 
                                       'orig_prob_rf', 'orig_prob_lstm']].apply(np.mean, 1)

In [None]:
vitals = ['heartrate_mean', 'sysbp_mean', 'diasbp_mean', 'meanbp_mean',
          'resprate_mean', 'tempc_mean', 'spo2_mean', 'glucose_mean']

labs = ['aniongap', 'albumin', 'bicarbonate', 'bilirubin', 'creatinine', 
        'chloride', 'glucose', 'hemoglobin', 'lactate', 
        'magnesium', 'phosphate', 'platelet', 'potassium', 'ptt', 'inr', 
        'pt', 'sodium', 'bun', 'wbc']  # -hematocrit

comobs = ['congestive_heart_failure', 'chronic_pulmonary', 'pulmonary_circulation']

others = ['age', 'gender']

features = vitals+labs

In [None]:
def most_abnormal(x, features):
    res = None
    feat = None
    for f in features:
        if f == 'resprate_mean':
            continue
        if res is None or abs(x[f]) > res:
            res = abs(x[f])
            feat = f
    return feat

In [None]:
joined['abnormal'] = joined.apply(lambda x: most_abnormal(x, features), axis=1)

In [None]:
joined[features+['abnormal','best_feat']].head()

In [None]:
np.sum(joined.best_feat=='resprate_mean'), np.sum(joined.abnormal=='resprate_mean')

In [None]:
np.sum(joined['abnormal'] == joined['best_feat']), np.sum(joined['abnormal'] != joined['best_feat'])

In [None]:
7736 / (7736+33293), (7736+33293) == len(joined)

In [None]:
len(joined) == 8310 + 32719

In [None]:
8310 / len(joined)

In [None]:
print(np.sum((joined.new_risk < 0.5) & 
             (joined.best_feat != 'resprate_mean') & 
             (joined.abnormal != 'resprate_mean')
             (joined.orig_prob_ensemble > 0.5)))

In [None]:
print(np.sum((joined.new_risk < 0.5) & 
             (joined.best_feat != 'resprate_mean') & 
             (joined.orig_prob_ensemble > 0.5) & 
             (joined.abnormal != joined.best_feat)))

In [None]:
906 / 1425

In [None]:
print(np.sum((joined.label==1)& (joined.orig_prob_ensemble > 0.5))) 
print(np.sum((joined.label==1) & 
             (joined.new_risk < 0.5) & 
             (joined.best_feat != 'resprate_mean') & 
             (joined.orig_prob_ensemble > 0.5)))

In [None]:
554 / 1540

In [None]:
joined[joined.label == 1]['best_feat'].value_counts()

In [None]:
index = 31760
fig1, ax1 = mimic.best_feat_plot(joined, df_cohort, index, title='for Mean Arterial Pressure');

In [None]:
joined.loc[index][['case','timepoint','abnormal','best_feat','orig_prob_ensemble']]

In [None]:
df_series_valid.loc[joined.loc[index]['case'], joined.loc[index]['timepoint']].sort_values()

In [None]:
index = 10406
fig2, ax2 = mimic.best_feat_plot(joined, df_cohort, index, title='for Mean Arterial Pressure');

In [None]:
joined.loc[index][['case','timepoint','abnormal','best_feat','orig_prob_ensemble']]

In [None]:
joined[(joined.case==15396) & 
       (joined.timepoint==11)][['case','timepoint','abnormal','best_feat','orig_prob_ensemble']]

In [None]:
df_cohort[(df_cohort.subject_id==joined.loc[index]['case']) & 
          (df_cohort.timepoint==11)][features]

In [None]:
series_means['sysbp_mean']

In [None]:
df_series_valid.loc[joined.loc[index]['case'], joined.loc[index]['timepoint']].sort_values()

In [None]:
df_time[(df_time.subject_id==joined.loc[index]['case']) & 
          (df_time.timepoint==joined.loc[index]['timepoint'])]

In [None]:
df_cohort = mimic.get_mimic_data()

df_ml = mimic.get_ml_data(df_cohort)

df_time = mimic.get_ml_series_data(df_cohort)

_, x_validation, stsc, _, xst_validation, _, Y_validation = mimic.train_valid_ml(df_ml)

stsc2, series_means, _, df_series_valid, _, _, xt_valid, Yt_valid = mimic.train_valid_series(df_time, Y_validation)

In [None]:
df_series_valid.loc[joined.loc[index]['case'], joined.loc[index]['timepoint']].sort_values()

In [None]:
len(df_cohort)

In [None]:
vitals = ['heartrate_mean', 'sysbp_mean', 'diasbp_mean', 'meanbp_mean',
          'resprate_mean', 'tempc_mean', 'spo2_mean', 'glucose_mean']

labs = ['aniongap', 'albumin', 'bicarbonate', 'bilirubin', 'creatinine', 
        'chloride', 'glucose', 'hemoglobin', 'lactate', 
        'magnesium', 'phosphate', 'platelet', 'potassium', 'ptt', 'inr', 
        'pt', 'sodium', 'bun', 'wbc']  # -hematocrit

comobs = ['congestive_heart_failure', 'chronic_pulmonary', 'pulmonary_circulation']

others = ['age', 'gender']

features = vitals+labs

In [None]:
def get_magecs(p=None, models=('lr', 'mlp', 'svm', 'lstm')):
    prefix = 'magec_'
    postfix = '_0.csv' if p is None else '_p' + str(p) + '.csv'
    magecs = []
    for m in models:
        filename = prefix + m + postfix
        magecs.append(pd.read_csv(filename))
    return magecs

In [None]:
x_magec_cols = list(set(df_series_valid.columns) - {'label'})
x_magec = df_series_valid[x_magec_cols]
y_magec = df_series_valid['label']

In [None]:
magecs = get_magecs()

In [None]:
joined = mg.magec_models(*magecs, Xdata=x_magec, Ydata=y_magec, features=features)

In [None]:
prob_cols = [c for c in joined.columns if c.startswith('perturb')]
joined['orig_prob_ensemble'] = joined[['orig_prob_mlp', 'orig_prob_lr', 
                                       'orig_prob_svm', 'orig_prob_lstm']].apply(np.mean, 1)
joined[['best_feat', 'new_risk']] = joined.apply(lambda x: mimic.best_feature(x, prob_cols), axis=1)

In [None]:
len(joined)

In [None]:
joined.loc[31760]

In [None]:
index = 31760
fig1, ax1 = mimic.best_feat_plot(joined, df_cohort, index, title='for Mean Arterial Pressure');

In [None]:
timepoint = 14
txt = 'Clinical values and ranked MAgECs at time point {}'
fig  = mg.panel_plot(xst_validation.columns, features, stsc2, joined, 
                                                      joined.loc[index].case, timepoint, 
                     models=('lr','svm','mlp','lstm', 'ensemble'), label='label', limit=6, rotate=25, 
                  title=txt.format(timepoint))

In [None]:
index = 10406
fig2, ax2 = mimic.best_feat_plot(joined, df_cohort, index, title='for Mean Arterial Pressure');

In [None]:
from scipy import interpolate
from adjustText import adjust_text

def best_feat_plot(joined, cohort, index, title='', save=False, feat=None):
    data = joined.loc[index]
    case, t_0, label, orig_prob, new_risk = data[['case', 'timepoint', 'label',
                                                  'orig_prob_ensemble', 'new_risk']]
    if feat is None:
        best_feat = joined.loc[index]['best_feat']
    else:
        best_feat = feat

    xy = cohort[cohort['subject_id'] == case][['timepoint', best_feat]].values
    x = [int(x[0]) for x in xy]
    yy = [x[1] for x in xy]

    fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(16, 6))

    mimic.plot_feature(ax[0], x, yy, best_feat, case, label, title=title)

    xyzw = joined[joined.case == case][['timepoint', 'orig_prob_ensemble', 'best_feat', 'new_risk']].values
    x = [int(x[0]) for x in xyzw]
    y = [x[1] for x in xyzw]
    z = [x[2] for x in xyzw]
    w = [x[3] for x in xyzw]
    zz = [cohort[(cohort['subject_id'] == case) & (cohort['timepoint'] == x[i])][feat].values[0]
          for i, feat in enumerate(z)]

    plot_risk(ax[1], x, y, z, w, zz, case, label, feat=feat)
    ax[1].invert_xaxis()

    if save:
        plt.savefig('case_{}_series.png'.format(case))

    return fig, ax

def plot_risk(ax, x, y, z, w, yy, case, label, feat=None):
    ax.plot(x, y, 'rx--')
    ax.plot(np.linspace(ax.get_xlim()[0], ax.get_xlim()[1], 10), 0.5 * np.ones(10), '--')
    txt = 'Case {}: Hourly Estimated Ensemble Risk (Outcome: {})'
    ax.set_title(txt.format(case, label))
    ax.set_ylabel('Ensemble Risk')
    ax.set_xlabel('hours to event')
    ax.grid('on')
    # ax.set_ylim([0.2, 0.9])

    texts = []

    for i, txt in enumerate(z):
        if np.isnan(yy[i]):
            continue
        if feat is not None and feat == txt:    
            msg = txt + ' = {:.2f}\n perturb. risk = {:.2g}'.format(yy[i], w[i])
            texts.append(ax.text(x[i], y[i], msg))
        elif feat is None and (w[i] < 0.5 < y[i]):
            msg = txt + ' = {:.2f}\n perturb. risk = {:.2g}'.format(yy[i], w[i])
            texts.append(ax.text(x[i], y[i], msg))

    if feat is None:
        f = interpolate.interp1d(x, y)
        x = np.linspace(min(x), max(x), 140)
        y = f(x)
        adjust_text(texts, x, y, arrowprops=dict(arrowstyle="->", color='b', lw=0.5), autoalign='xy')
    return

In [None]:
index = 10406
fig2, ax2 = best_feat_plot(joined, df_cohort, index, title='for Mean Arterial Pressure', feat='meanbp_mean');